Skip to main content

mz_transform/
typecheck.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Check that the visible type of each query has not been changed
11
12use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19    AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20    RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27    ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28};
29
/// Typechecking contexts as shared by various typechecking passes.
///
/// The context is wrapped in an `Arc<Mutex<_>>` so that multiple typechecker passes can share
/// one and the same context. Shared contexts help catch consistency issues.
pub type SharedTypecheckingContext = Arc<Mutex<Context>>;
35
36/// Generates an empty context
37pub fn empty_typechecking_context() -> SharedTypecheckingContext {
38    Arc::new(Mutex::new(BTreeMap::new()))
39}
40
/// The possible forms of inconsistency/errors discovered during typechecking.
///
/// Every variant except `Recursion` has a `source` field identifying the MIR term that is home
/// to the error (though not necessarily the root cause of the error). The lifetime `'a` is
/// that of the borrowed MIR terms an error points into.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// Unbound identifiers (local or global)
    Unbound {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The (unbound) identifier referenced
        id: Id,
        /// The type `id` was expected to have
        typ: ReprRelationType,
    },
    /// Dereference of a non-existent column
    NoSuchColumn {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// Scalar expression that references an invalid column
        expr: &'a MirScalarExpr,
        /// The invalid column referenced
        col: usize,
    },
    /// A single column type does not match
    MismatchColumn {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column type we found (`sub` type)
        got: ReprColumnType,
        /// The column type we expected (`sup` type)
        expected: ReprColumnType,
        /// The difference between these types
        diffs: Vec<ReprColumnTypeDifference>,
        /// An explanatory message
        message: String,
    },
    /// Relation column types do not match
    MismatchColumns {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column types we found (`sub` type)
        got: Vec<ReprColumnType>,
        /// The column types we expected (`sup` type)
        expected: Vec<ReprColumnType>,
        /// The difference between these types
        diffs: Vec<ReprRelationTypeDifference>,
        /// An explanatory message
        message: String,
    },
    /// A constant row does not have the expected number of columns
    BadConstantRowLen {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The number of columns (datums) in the constant row
        got: usize,
        /// The expected column types (whose count the row does not match)
        expected: Vec<ReprColumnType>,
    },
    /// A constant row has the expected number of columns, but some datums have the wrong type
    BadConstantRow {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The columns that mismatched, with a debug rendering of the datum we got and the expected type
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The expected type (which that row does not have)
        expected: Vec<ReprColumnType>,
        // TODO(mgree) with a good way to get the type of a Datum, we could give a detailed diff here
        // this is difficult for empty structures where we may not know the element type
    },
    /// Projection of a non-existent column
    BadProject {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column projected
        got: Vec<usize>,
        /// The input columns (which don't have that column)
        input_type: Vec<ReprColumnType>,
    },
    /// An equivalence class in a join was malformed
    BadJoinEquivalence {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column types found in the offending join equivalence
        got: Vec<ReprColumnType>,
        /// The problem with the join equivalences
        message: String,
    },
    /// TopK grouping by non-existent column
    BadTopKGroupKey {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The bad column reference in the group key
        k: usize,
        /// The input columns (which don't have that column)
        input_type: Vec<ReprColumnType>,
    },
    /// TopK ordering by non-existent column
    BadTopKOrdering {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The ordering used
        order: ColumnOrder,
        /// The input columns (which don't work for that ordering)
        input_type: Vec<ReprColumnType>,
    },
    /// LetRec bindings are malformed
    BadLetRecBindings {
        /// Expression with the bug
        source: &'a MirRelationExpr,
    },
    /// Local identifiers are shadowed
    Shadowing {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The id that was shadowed
        id: Id,
    },
    /// Recursion depth exceeded
    Recursion {
        /// The error that aborted recursion
        error: RecursionLimitError,
    },
    /// A dummy value was found
    DisallowedDummy {
        /// The expression with the dummy value
        source: &'a MirRelationExpr,
    },
}
170
171impl<'a> From<RecursionLimitError> for TypeError<'a> {
172    fn from(error: RecursionLimitError) -> Self {
173        TypeError::Recursion { error }
174    }
175}
176
/// Maps each identifier (local or global) to the column types recorded for it.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
178
/// Characterizes differences between relation types
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
// Variant order is left untouched: it feeds the derived `Hash` implementation.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// `sub` and `sup` don't have the same number of columns
    Length {
        /// Length of `sub`
        len_sub: usize,
        /// Length of `sup`
        len_sup: usize,
    },
    /// `sub` and `sup` differ at the indicated column
    Column {
        /// The (zero-based) column at which `sub` and `sup` differ
        col: usize,
        /// The difference between `sub` and `sup`
        diff: ReprColumnTypeDifference,
    },
}
199
/// Characterizes differences between individual column types
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
/// There may be multiple reasons, e.g., `sub` may be missing fields and have fields of different types
// Variant order is left untouched: it feeds the derived `Hash` implementation.
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The `ReprScalarBaseType` of `sub` doesn't match that of `sup`
    /// (also used when both are records with different numbers of fields)
    NotSubtype {
        /// Would-be subtype
        sub: ReprScalarType,
        /// Would-be supertype
        sup: ReprScalarType,
    },
    /// `sub` was nullable but `sup` was not
    Nullability {
        /// Would-be subtype
        sub: ReprColumnType,
        /// Would-be supertype
        sup: ReprColumnType,
    },
    /// Both `sub` and `sup` are a list, map, array, or range, but `sub`'s element type differed from `sup`'s
    /// (for maps, the "element type" is the value type)
    ElementType {
        /// The type constructor (list, array, etc.)
        ctor: String,
        /// The difference in the element type
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// `sub` and `sup` are both records, but `sub` is missing fields present in `sup`
    RecordMissingFields {
        /// The missing fields
        missing: Vec<ColumnName>,
    },
    /// `sub` and `sup` are both records, but some fields in `sub` are not subtypes of fields in `sup`
    RecordFields {
        /// The differences, by field
        fields: Vec<ReprColumnTypeDifference>,
    },
}
238
239impl ReprRelationTypeDifference {
240    /// Returns the same type difference, but ignoring nullability
241    ///
242    /// Returns `None` when _all_ of the differences are due to nullability
243    pub fn ignore_nullability(self) -> Option<Self> {
244        use ReprRelationTypeDifference::*;
245
246        match self {
247            Length { .. } => Some(self),
248            Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
249        }
250    }
251}
252
253impl ReprColumnTypeDifference {
254    /// Returns the same type difference, but ignoring nullability
255    ///
256    /// Returns `None` when _all_ of the differences are due to nullability
257    pub fn ignore_nullability(self) -> Option<Self> {
258        use ReprColumnTypeDifference::*;
259
260        match self {
261            Nullability { .. } => None,
262            NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
263            ElementType { ctor, element_type } => {
264                element_type
265                    .ignore_nullability()
266                    .map(|element_type| ElementType {
267                        ctor,
268                        element_type: Box::new(element_type),
269                    })
270            }
271            RecordFields { fields } => {
272                let fields = fields
273                    .into_iter()
274                    .flat_map(|diff| diff.ignore_nullability())
275                    .collect::<Vec<_>>();
276
277                if fields.is_empty() {
278                    None
279                } else {
280                    Some(RecordFields { fields })
281                }
282            }
283        }
284    }
285}
286
287/// Returns a list of differences that make `sub` not a subtype of `sup`
288///
289/// This function returns an empty list when `sub` is a subtype of `sup`
290pub fn relation_subtype_difference(
291    sub: &[ReprColumnType],
292    sup: &[ReprColumnType],
293) -> Vec<ReprRelationTypeDifference> {
294    let mut diffs = Vec::new();
295
296    if sub.len() != sup.len() {
297        diffs.push(ReprRelationTypeDifference::Length {
298            len_sub: sub.len(),
299            len_sup: sup.len(),
300        });
301
302        // TODO(mgree) we could do an edit-distance computation to report more errors
303        return diffs;
304    }
305
306    diffs.extend(
307        sub.iter()
308            .zip_eq(sup.iter())
309            .enumerate()
310            .flat_map(|(col, (sub_ty, sup_ty))| {
311                column_subtype_difference(sub_ty, sup_ty)
312                    .into_iter()
313                    .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
314            }),
315    );
316
317    diffs
318}
319
320/// Returns a list of differences that make `sub` not a subtype of `sup`
321///
322/// This function returns an empty list when `sub` is a subtype of `sup`
323pub fn column_subtype_difference(
324    sub: &ReprColumnType,
325    sup: &ReprColumnType,
326) -> Vec<ReprColumnTypeDifference> {
327    let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
328
329    if sub.nullable && !sup.nullable {
330        diffs.push(ReprColumnTypeDifference::Nullability {
331            sub: sub.clone(),
332            sup: sup.clone(),
333        });
334    }
335
336    diffs
337}
338
339/// Returns a list of differences that make `sub` not a subtype of `sup`
340///
341/// This function returns an empty list when `sub` is a subtype of `sup`
342pub fn scalar_subtype_difference(
343    sub: &ReprScalarType,
344    sup: &ReprScalarType,
345) -> Vec<ReprColumnTypeDifference> {
346    use ReprScalarType::*;
347
348    let mut diffs = Vec::new();
349
350    match (sub, sup) {
351        (
352            List {
353                element_type: sub_elt,
354                ..
355            },
356            List {
357                element_type: sup_elt,
358                ..
359            },
360        )
361        | (
362            Map {
363                value_type: sub_elt,
364                ..
365            },
366            Map {
367                value_type: sup_elt,
368                ..
369            },
370        )
371        | (
372            Range {
373                element_type: sub_elt,
374                ..
375            },
376            Range {
377                element_type: sup_elt,
378                ..
379            },
380        )
381        | (Array(sub_elt), Array(sup_elt)) => {
382            let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
383            diffs.extend(
384                scalar_subtype_difference(sub_elt, sup_elt)
385                    .into_iter()
386                    .map(|diff| ReprColumnTypeDifference::ElementType {
387                        ctor: ctor.clone(),
388                        element_type: Box::new(diff),
389                    }),
390            );
391        }
392        (
393            Record {
394                fields: sub_fields, ..
395            },
396            Record {
397                fields: sup_fields, ..
398            },
399        ) => {
400            if sub_fields.len() != sup_fields.len() {
401                diffs.push(ReprColumnTypeDifference::NotSubtype {
402                    sub: sub.clone(),
403                    sup: sup.clone(),
404                });
405                return diffs;
406            }
407
408            for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
409                diffs.extend(column_subtype_difference(sub_ty, sup_ty));
410            }
411        }
412        (_, _) => {
413            if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
414                diffs.push(ReprColumnTypeDifference::NotSubtype {
415                    sub: sub.clone(),
416                    sup: sup.clone(),
417                })
418            }
419        }
420    };
421
422    diffs
423}
424
425/// Unions `other` into `typ`, returning a list of differences on failure
426///
427/// This function returns an empty list when `typ` and `other` are a union
428pub fn scalar_union(
429    typ: &mut ReprScalarType,
430    other: &ReprScalarType,
431) -> Vec<ReprColumnTypeDifference> {
432    use ReprScalarType::*;
433
434    let mut diffs = Vec::new();
435
436    // precomputing to appease the borrow checker
437    let ctor = ReprScalarBaseType::from(&*typ);
438    match (typ, other) {
439        (
440            List {
441                element_type: typ_elt,
442            },
443            List {
444                element_type: other_elt,
445            },
446        )
447        | (
448            Map {
449                value_type: typ_elt,
450            },
451            Map {
452                value_type: other_elt,
453            },
454        )
455        | (
456            Range {
457                element_type: typ_elt,
458            },
459            Range {
460                element_type: other_elt,
461            },
462        )
463        | (Array(typ_elt), Array(other_elt)) => {
464            let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
465            diffs.extend(
466                res.into_iter()
467                    .map(|diff| ReprColumnTypeDifference::ElementType {
468                        ctor: format!("{ctor:?}"),
469                        element_type: Box::new(diff),
470                    }),
471            );
472        }
473        (
474            Record { fields: typ_fields },
475            Record {
476                fields: other_fields,
477            },
478        ) => {
479            if typ_fields.len() != other_fields.len() {
480                diffs.push(ReprColumnTypeDifference::NotSubtype {
481                    sub: ReprScalarType::Record {
482                        fields: typ_fields.clone(),
483                    },
484                    sup: other.clone(),
485                });
486                return diffs;
487            }
488
489            for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
490                diffs.extend(column_union(typ_ty, other_ty));
491            }
492        }
493        (typ, _) => {
494            if ctor != ReprScalarBaseType::from(other) {
495                diffs.push(ReprColumnTypeDifference::NotSubtype {
496                    sub: typ.clone(),
497                    sup: other.clone(),
498                })
499            }
500        }
501    };
502
503    diffs
504}
505
506/// Unions `other` into `typ`, returning a list of differences on failure
507///
508/// This function returns an empty list when `typ` and `other` are a union
509pub fn column_union(
510    typ: &mut ReprColumnType,
511    other: &ReprColumnType,
512) -> Vec<ReprColumnTypeDifference> {
513    let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
514
515    if diffs.is_empty() {
516        typ.nullable |= other.nullable;
517    }
518
519    diffs
520}
521
522/// Returns true when it is safe to treat a `sub` row as an `sup` row
523///
524/// In particular, the core types must be equal, and if a column in `sup` is nullable, that column should also be nullable in `sub`
525/// Conversely, it is okay to treat a known non-nullable column as nullable: `sub` may be nullable when `sup` is not
526pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
527    if sub.len() != sup.len() {
528        return false;
529    }
530
531    sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
532        (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
533    })
534}
535
/// Characterizes how a Datum differs from a ReprColumnType
///
/// Produced by `datum_difference_with_column_type`. Nested mismatches inside containers are
/// wrapped in `ElementType`, one layer per level of nesting.
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// Datum was null, but expected a non-null value (i.e., a ReprScalarType)
    Null {
        /// The representation scalar type that we expected
        expected: ReprScalarType,
    },
    /// Datum's kind doesn't match the expected ReprScalarType
    Mismatch {
        /// The debug rendering of the Datum that we got (we don't store the Datum itself to avoid borrowing issues)
        /// This debug rendering includes the DatumKind.
        got_debug: String,
        /// The representation scalar type that we expected
        expected: ReprScalarType,
    },
    /// Datum's dimensions didn't match the type's dimensions
    /// (for records: the number of fields)
    MismatchDimensions {
        /// The type constructor (list, array, int2vector, etc.)
        ctor: String,
        /// The dimension of the datum that we got
        got: usize,
        /// The dimension of the type that we expected
        expected: usize,
    },
    /// There was a nested difference in the element type of some container type
    ElementType {
        /// The type constructor (list, array, int2vector, etc.)
        ctor: String,
        /// The difference in the element type
        element_type: Box<DatumTypeDifference>,
    },
}
569
/// Computes the difference between this datum and the specified (representation) column type.
///
/// Returns `Ok(())` when `datum` fits `column_type`. Otherwise it returns a
/// [`DatumTypeDifference`] describing the first mismatch encountered. Unlike `is_instance_of`,
/// a `Datum::Dummy` is accepted at every type.
///
/// See [`Datum<'a>::is_instance_of`] for just getting a yes/no answer.
///
/// See [`Datum<'a>::is_instance_of_sql`] for comparing `Datum`s to `SqlColumnType`s.
fn datum_difference_with_column_type(
    datum: &Datum<'_>,
    column_type: &ReprColumnType,
) -> Result<(), DatumTypeDifference> {
    // Checks `datum` against `scalar_type` alone. Column-level nullability is handled by the
    // outer function and by the container arms below, so a `Datum::Null` reaching this function
    // is always reported as `DatumTypeDifference::Null`.
    fn difference_with_scalar_type(
        datum: &Datum<'_>,
        scalar_type: &ReprScalarType,
    ) -> Result<(), DatumTypeDifference> {
        // Reports that `got` does not have the scalar type `expected`.
        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
            Err(DatumTypeDifference::Mismatch {
                // this will be redacted as appropriate
                got_debug: format!("{got:?}"),
                expected: expected.clone(),
            })
        }

        if let ReprScalarType::Jsonb = scalar_type {
            // json type checking: a JSON value is a JSON scalar, or a list/map whose
            // elements/values are themselves (recursively) JSON values
            match datum {
                Datum::Dummy => Ok(()), // allow dummys (unlike is_instance_of)
                Datum::Null => Err(DatumTypeDifference::Null {
                    expected: ReprScalarType::Jsonb,
                }),
                Datum::JsonNull
                | Datum::False
                | Datum::True
                | Datum::Numeric(_)
                | Datum::String(_) => Ok(()),
                Datum::List(list) => {
                    for elem in list.iter() {
                        difference_with_scalar_type(&elem, scalar_type)?;
                    }
                    Ok(())
                }
                Datum::Map(dict) => {
                    for (_, val) in dict.iter() {
                        difference_with_scalar_type(&val, scalar_type)?;
                    }
                    Ok(())
                }
                _ => mismatch(datum, scalar_type),
            }
        } else {
            // Wraps a nested difference with the name of the enclosing container type.
            fn element_type_difference(
                ctor: &str,
                element_type: DatumTypeDifference,
            ) -> DatumTypeDifference {
                DatumTypeDifference::ElementType {
                    ctor: ctor.to_string(),
                    element_type: Box::new(element_type),
                }
            }
            match (datum, scalar_type) {
                (Datum::Dummy, _) => Ok(()), // allow dummys (unlike is_instance_of)
                (Datum::Null, _) => Err(DatumTypeDifference::Null {
                    expected: scalar_type.clone(),
                }),
                (Datum::False, ReprScalarType::Bool) => Ok(()),
                (Datum::False, _) => mismatch(datum, scalar_type),
                (Datum::True, ReprScalarType::Bool) => Ok(()),
                (Datum::True, _) => mismatch(datum, scalar_type),
                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
                (Datum::Date(_), _) => mismatch(datum, scalar_type),
                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
                (Datum::Time(_), _) => mismatch(datum, scalar_type),
                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
                (Datum::String(_), ReprScalarType::String) => Ok(()),
                (Datum::String(_), _) => mismatch(datum, scalar_type),
                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
                // Array elements may be null regardless of the element type.
                (Datum::Array(array), ReprScalarType::Array(t)) => {
                    for e in array.elements().iter() {
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, t)
                            .map_err(|e| element_type_difference("array", e))?;
                    }
                    Ok(())
                }
                // Unlike general arrays, int2vector elements are not allowed to be null.
                (Datum::Array(array), ReprScalarType::Int2Vector) => {
                    if !array.has_int2vector_dims() {
                        // Int2Vector is 1-D, but empty arrays use 0 dimensions (PostgreSQL convention).
                        // The error message just mentions 1 dimension, but 0 dimensions are allowed when empty.
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "int2vector".to_string(),
                            got: array.dims().len(),
                            expected: 1,
                        });
                    }

                    for e in array.elements().iter() {
                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
                            .map_err(|e| element_type_difference("int2vector", e))?;
                    }

                    Ok(())
                }
                (Datum::Array(_), _) => mismatch(datum, scalar_type),
                // List elements may be null regardless of the element type.
                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
                    for e in list.iter() {
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, element_type)
                            .map_err(|e| element_type_difference("list", e))?;
                    }
                    Ok(())
                }
                // Records are represented as lists with one datum per field; a null field is
                // accepted only if that field's own column type is nullable.
                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
                    let len = list.iter().count();
                    if len != fields.len() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "record".to_string(),
                            got: len,
                            expected: fields.len(),
                        });
                    }

                    for (e, t) in list.iter().zip_eq(fields) {
                        if let Datum::Null = e {
                            if t.nullable {
                                continue;
                            } else {
                                return Err(DatumTypeDifference::Null {
                                    expected: t.scalar_type.clone(),
                                });
                            }
                        }

                        difference_with_scalar_type(&e, &t.scalar_type)
                            .map_err(|e| element_type_difference("record", e))?;
                    }
                    Ok(())
                }
                (Datum::List(_), _) => mismatch(datum, scalar_type),
                // Map values may be null regardless of the value type; keys are not checked.
                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    for (_, v) in map.iter() {
                        if let Datum::Null = v {
                            continue;
                        }

                        difference_with_scalar_type(&v, value_type)
                            .map_err(|e| element_type_difference("map", e))?;
                    }
                    Ok(())
                }
                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
                // A range with no `inner` is accepted outright; otherwise each bound that is
                // present is checked against the element type.
                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
                    match inner {
                        None => Ok(()),
                        Some(inner) => {
                            if let Some(b) = inner.lower.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            if let Some(b) = inner.upper.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            Ok(())
                        }
                    }
                }
                (Datum::Range(_), _) => mismatch(datum, scalar_type),
                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
            }
        }
    }
    // A top-level null is fine for a nullable column; everything else is checked against the
    // scalar type (which rejects `Datum::Null`).
    if column_type.nullable {
        if let Datum::Null = datum {
            return Ok(());
        }
    }
    difference_with_scalar_type(datum, &column_type.scalar_type)
}
786
787fn row_difference_with_column_types<'a>(
788    source: &'a MirRelationExpr,
789    datums: &Vec<Datum<'_>>,
790    column_types: &[ReprColumnType],
791) -> Result<(), TypeError<'a>> {
792    // correct length
793    if datums.len() != column_types.len() {
794        return Err(TypeError::BadConstantRowLen {
795            source,
796            got: datums.len(),
797            expected: column_types.to_vec(),
798        });
799    }
800
801    // correct types
802    let mut mismatches = Vec::new();
803    for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
804        if let Err(e) = datum_difference_with_column_type(d, ty) {
805            mismatches.push((i, e));
806        }
807    }
808    if !mismatches.is_empty() {
809        return Err(TypeError::BadConstantRow {
810            source,
811            mismatches,
812            expected: column_types.to_vec(),
813        });
814    }
815
816    Ok(())
817}
818/// Check that the visible type of each query has not been changed
819#[derive(Debug)]
820pub struct Typecheck {
821    /// The known types of the queries so far
822    ctx: SharedTypecheckingContext,
823    /// Whether or not this is the first run of the transform
824    disallow_new_globals: bool,
825    /// Whether or not to be strict about join equivalences having the same nullability
826    strict_join_equivalences: bool,
827    /// Whether or not to disallow dummy values
828    disallow_dummy: bool,
829    /// Recursion guard for checked recursion
830    recursion_guard: RecursionGuard,
831}
832
833impl CheckedRecursion for Typecheck {
834    fn recursion_guard(&self) -> &RecursionGuard {
835        &self.recursion_guard
836    }
837}
838
839impl Typecheck {
840    /// Creates a typechecking consistency checking pass using a given shared context
841    pub fn new(ctx: SharedTypecheckingContext) -> Self {
842        Self {
843            ctx,
844            disallow_new_globals: false,
845            strict_join_equivalences: false,
846            disallow_dummy: false,
847            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
848        }
849    }
850
851    /// New non-transient global IDs will be treated as an error
852    ///
853    /// Only turn this on after the context has been appropriately populated by, e.g., an earlier run
854    pub fn disallow_new_globals(mut self) -> Self {
855        self.disallow_new_globals = true;
856        self
857    }
858
859    /// Equivalence classes in joins must not only agree on scalar type, but also on nullability
860    ///
861    /// Only turn this on before `JoinImplementation`
862    pub fn strict_join_equivalences(mut self) -> Self {
863        self.strict_join_equivalences = true;
864
865        self
866    }
867
868    /// Disallow dummy values
869    pub fn disallow_dummy(mut self) -> Self {
870        self.disallow_dummy = true;
871        self
872    }
873
874    /// Returns the type of a relation expression or a type error.
875    ///
876    /// This function is careful to check validity, not just find out the type.
877    ///
878    /// It should be linear in the size of the AST.
879    ///
880    /// ??? should we also compute keys and return a `ReprRelationType`?
881    ///   ggevay: Checking keys would have the same problem as checking nullability: key inference
882    ///   is very heuristic (even more so than nullability inference), so it's almost impossible to
883    ///   reliably keep it stable across transformations.
884    pub fn typecheck<'a>(
885        &self,
886        expr: &'a MirRelationExpr,
887        ctx: &Context,
888    ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
889        use MirRelationExpr::*;
890
891        self.checked_recur(|tc| match expr {
892            Constant { typ, rows } => {
893                if let Ok(rows) = rows {
894                    for (row, _id) in rows {
895                        let datums = row.unpack();
896
897                        let col_types = typ
898                            .column_types
899                            .iter()
900                            .cloned()
901                            .collect_vec();
902                        row_difference_with_column_types(
903                            expr, &datums, &col_types,
904                        )?;
905
906                        if self.disallow_dummy
907                            && datums.iter().any(|d| d == &mz_repr::Datum::Dummy)
908                        {
909                            return Err(TypeError::DisallowedDummy {
910                                source: expr,
911                            });
912                        }
913                    }
914                }
915
916                Ok(typ.column_types.iter().cloned().collect_vec())
917            }
918            Get { typ, id, .. } => {
919                if let Id::Global(_global_id) = id {
920                    if !ctx.contains_key(id) {
921                        // TODO(mgree) pass QueryContext through to check these types
922                        return Ok(typ.column_types.iter().cloned().collect_vec());
923                    }
924                }
925
926                let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
927                    source: expr,
928                    id: id.clone(),
929                    typ: typ.clone(),
930                })?;
931
932                let column_types = typ.column_types.iter().cloned().collect_vec();
933
934                // covariant: the ascribed type must be a subtype of the actual type in the context
935                let diffs = relation_subtype_difference(&column_types, ctx_typ)
936                    .into_iter()
937                    .flat_map(|diff| diff.ignore_nullability())
938                    .collect::<Vec<_>>();
939
940                if !diffs.is_empty() {
941                    return Err(TypeError::MismatchColumns {
942                        source: expr,
943                        got: column_types,
944                        expected: ctx_typ.clone(),
945                        diffs,
946                        message: "annotation did not match context type".to_string(),
947                    });
948                }
949
950                Ok(column_types)
951            }
952            Project { input, outputs } => {
953                let t_in = tc.typecheck(input, ctx)?;
954
955                for x in outputs {
956                    if *x >= t_in.len() {
957                        return Err(TypeError::BadProject {
958                            source: expr,
959                            got: outputs.clone(),
960                            input_type: t_in,
961                        });
962                    }
963                }
964
965                Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
966            }
967            Map { input, scalars } => {
968                let mut t_in = tc.typecheck(input, ctx)?;
969
970                for scalar_expr in scalars.iter() {
971                    t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
972
973                    if self.disallow_dummy && scalar_expr.contains_dummy() {
974                        return Err(TypeError::DisallowedDummy {
975                            source: expr,
976                        });
977                    }
978                }
979
980                Ok(t_in)
981            }
982            FlatMap { input, func, exprs } => {
983                let mut t_in = tc.typecheck(input, ctx)?;
984
985                for scalar_expr in exprs {
986                    // TODO(mgree) check result agrees with `func`'s input type
987                    let _t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
988
989                    if self.disallow_dummy && scalar_expr.contains_dummy() {
990                        return Err(TypeError::DisallowedDummy {
991                            source: expr,
992                        });
993                    }
994                }
995
996                let t_out: Vec<ReprColumnType> = func
997                    .output_type().column_types;
998
999                // FlatMap extends the existing columns
1000                t_in.extend(t_out);
1001                Ok(t_in)
1002            }
1003            Filter { input, predicates } => {
1004                let mut t_in = tc.typecheck(input, ctx)?;
1005
1006                // Set as nonnull any columns where null values would cause
1007                // any predicate to evaluate to null.
1008                for column in non_nullable_columns(predicates) {
1009                    t_in[column].nullable = false;
1010                }
1011
1012                for scalar_expr in predicates {
1013                    let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1014
1015                    // filter condition must be boolean
1016                    // ignoring nullability: null is treated as false
1017                    // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1018                    if t.scalar_type != ReprScalarType::Bool {
1019                        let sub = t.scalar_type.clone();
1020
1021                        return Err(TypeError::MismatchColumn {
1022                            source: expr,
1023                            got: t,
1024                            expected: ReprColumnType {
1025                                scalar_type: ReprScalarType::Bool,
1026                                nullable: true,
1027                            },
1028                            diffs: vec![ReprColumnTypeDifference::NotSubtype {
1029                                sub,
1030                                sup: ReprScalarType::Bool,
1031                            }],
1032                            message: "expected boolean condition".to_string(),
1033                        });
1034                    }
1035
1036                    if self.disallow_dummy && scalar_expr.contains_dummy() {
1037                        return Err(TypeError::DisallowedDummy {
1038                            source: expr,
1039                        });
1040                    }
1041                }
1042
1043                Ok(t_in)
1044            }
1045            Join {
1046                inputs,
1047                equivalences,
1048                implementation,
1049            } => {
1050                let mut t_in_global = Vec::new();
1051                let mut t_in_local = vec![Vec::new(); inputs.len()];
1052
1053                for (i, input) in inputs.iter().enumerate() {
1054                    let input_t = tc.typecheck(input, ctx)?;
1055                    t_in_global.extend(input_t.clone());
1056                    t_in_local[i] = input_t;
1057                }
1058
1059                for eq_class in equivalences {
1060                    let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1061
1062                    let mut all_nullable = true;
1063
1064                    for scalar_expr in eq_class {
1065                        // Note: the equivalences have global column references
1066                        let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1067
1068                        if !t_expr.nullable {
1069                            all_nullable = false;
1070                        }
1071
1072                        if let Some(t_first) = t_exprs.get(0) {
1073                            let diffs = scalar_subtype_difference(
1074                                &t_expr.scalar_type,
1075                                &t_first.scalar_type,
1076                            );
1077                            if !diffs.is_empty() {
1078                                return Err(TypeError::MismatchColumn {
1079                                    source: expr,
1080                                    got: t_expr,
1081                                    expected: t_first.clone(),
1082                                    diffs,
1083                                    message: "equivalence class members \
1084                                        have different scalar types"
1085                                        .to_string(),
1086                                });
1087                            }
1088
1089                            // equivalences may or may not match on nullability
1090                            // before JoinImplementation runs, nullability should match.
1091                            // but afterwards, some nulls may appear that are actually being filtered out elsewhere
1092                            if self.strict_join_equivalences {
1093                                if t_expr.nullable != t_first.nullable {
1094                                    let sub = t_expr.clone();
1095                                    let sup = t_first.clone();
1096
1097                                    let err = TypeError::MismatchColumn {
1098                                        source: expr,
1099                                        got: t_expr.clone(),
1100                                        expected: t_first.clone(),
1101                                        diffs: vec![
1102                                            ReprColumnTypeDifference::Nullability { sub, sup },
1103                                        ],
1104                                        message: "equivalence class members have \
1105                                            different nullability (and join \
1106                                            equivalence checking is strict)"
1107                                            .to_string(),
1108                                    };
1109
1110                                    // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1111                                    ::tracing::debug!("{err}");
1112                                }
1113                            }
1114                        }
1115
1116                        if self.disallow_dummy && scalar_expr.contains_dummy() {
1117                            return Err(TypeError::DisallowedDummy {
1118                                source: expr,
1119                            });
1120                        }
1121
1122                        t_exprs.push(t_expr);
1123                    }
1124
1125                    if self.strict_join_equivalences && all_nullable {
1126                        let err = TypeError::BadJoinEquivalence {
1127                            source: expr,
1128                            got: t_exprs,
1129                            message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1130                        };
1131
1132                        // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1133                        ::tracing::debug!("{err}");
1134                    }
1135                }
1136
1137                // check that the join implementation is consistent
1138                match implementation {
1139                    JoinImplementation::Differential((start_idx, first_key, _), others) => {
1140                        if let Some(key) = first_key {
1141                            for k in key {
1142                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1143                            }
1144                        }
1145
1146                        for (idx, key, _) in others {
1147                            for k in key {
1148                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1149                            }
1150                        }
1151                    }
1152                    JoinImplementation::DeltaQuery(plans) => {
1153                        for plan in plans {
1154                            for (idx, key, _) in plan {
1155                                for k in key {
1156                                    let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1157                                }
1158                            }
1159                        }
1160                    }
1161                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1162                        let typ: Vec<ReprColumnType> = key
1163                            .iter()
1164                            .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1165                            .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1166
1167                        for row in consts {
1168                            let datums = row.unpack();
1169
1170                            row_difference_with_column_types(expr, &datums, &typ)?;
1171                        }
1172                    }
1173                    JoinImplementation::Unimplemented => (),
1174                }
1175
1176                Ok(t_in_global)
1177            }
1178            Reduce {
1179                input,
1180                group_key,
1181                aggregates,
1182                monotonic: _,
1183                expected_group_size: _,
1184            } => {
1185                let t_in = tc.typecheck(input, ctx)?;
1186
1187                let mut t_out = group_key
1188                    .iter()
1189                    .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1190                    .collect::<Result<Vec<_>, _>>()?;
1191
1192                    if self.disallow_dummy
1193                        && group_key
1194                            .iter()
1195                            .any(|scalar_expr| scalar_expr.contains_dummy())
1196                    {
1197                        return Err(TypeError::DisallowedDummy {
1198                            source: expr,
1199                        });
1200                    }
1201
1202                for agg in aggregates {
1203                    t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1204                }
1205
1206                Ok(t_out)
1207            }
1208            TopK {
1209                input,
1210                group_key,
1211                order_key,
1212                limit: _,
1213                offset: _,
1214                monotonic: _,
1215                expected_group_size: _,
1216            } => {
1217                let t_in = tc.typecheck(input, ctx)?;
1218
1219                for &k in group_key {
1220                    if k >= t_in.len() {
1221                        return Err(TypeError::BadTopKGroupKey {
1222                            source: expr,
1223                            k,
1224                            input_type: t_in,
1225                        });
1226                    }
1227                }
1228
1229                for order in order_key {
1230                    if order.column >= t_in.len() {
1231                        return Err(TypeError::BadTopKOrdering {
1232                            source: expr,
1233                            order: order.clone(),
1234                            input_type: t_in,
1235                        });
1236                    }
1237                }
1238
1239                Ok(t_in)
1240            }
1241            Negate { input } => tc.typecheck(input, ctx),
1242            Threshold { input } => tc.typecheck(input, ctx),
1243            Union { base, inputs } => {
1244                let mut t_base = tc.typecheck(base, ctx)?;
1245
1246                for input in inputs {
1247                    let t_input = tc.typecheck(input, ctx)?;
1248
1249                    let len_sub = t_base.len();
1250                    let len_sup = t_input.len();
1251                    if len_sub != len_sup {
1252                        return Err(TypeError::MismatchColumns {
1253                            source: expr,
1254                            got: t_base.clone(),
1255                            expected: t_input,
1256                            diffs: vec![ReprRelationTypeDifference::Length {
1257                                len_sub,
1258                                len_sup,
1259                            }],
1260                            message: "Union branches have different numbers of columns".to_string(),
1261                        });
1262                    }
1263
1264                    for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1265                        let diffs = column_union(base_col, &input_col);
1266                        if !diffs.is_empty() {
1267                            return Err(TypeError::MismatchColumn {
1268                                    source: expr,
1269                                    got: input_col,
1270                                    expected: base_col.clone(),
1271                                    diffs,
1272                                    message:
1273                                        "couldn't compute union of column types in Union"
1274                                    .to_string(),
1275                            });
1276                        }
1277
1278                    }
1279                }
1280
1281                Ok(t_base)
1282            }
1283            Let { id, value, body } => {
1284                let t_value = tc.typecheck(value, ctx)?;
1285
1286                let binding = Id::Local(*id);
1287                if ctx.contains_key(&binding) {
1288                    return Err(TypeError::Shadowing {
1289                        source: expr,
1290                        id: binding,
1291                    });
1292                }
1293
1294                let mut body_ctx = ctx.clone();
1295                body_ctx.insert(Id::Local(*id), t_value);
1296
1297                tc.typecheck(body, &body_ctx)
1298            }
1299            LetRec { ids, values, body, limits: _ } => {
1300                if ids.len() != values.len() {
1301                    return Err(TypeError::BadLetRecBindings { source: expr });
1302                }
1303
1304                // temporary hack: steal info from the Gets inside to learn the expected types
1305                // if no get occurs in any definition or the body, that means that relation is dead code (which is okay)
1306                let mut ctx = ctx.clone();
1307                // calling tc.collect_recursive_variable_types(expr, ...) triggers a panic due to nested letrecs with shadowing IDs
1308                for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1309                    tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1310                }
1311
1312                for (id, value) in ids.iter().zip_eq(values.iter()) {
1313                    let typ = tc.typecheck(value, &ctx)?;
1314
1315                    let id = Id::Local(id.clone());
1316                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1317                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1318                            // we expect an EXACT match, but don't care about nullability
1319                            let diffs = column_union(base_col, &input_col);
1320                            if !diffs.is_empty() {
1321                                 return Err(TypeError::MismatchColumn {
1322                                        source: expr,
1323                                        got: input_col,
1324                                        expected: base_col.clone(),
1325                                        diffs,
1326                                        message:
1327                                            "couldn't compute union of column types in LetRec"
1328                                        .to_string(),
1329                                    })
1330                            }
1331                        }
1332                    } else {
1333                        // dead code: no `Get` references this relation anywhere. we record the type anyway
1334                        ctx.insert(id, typ);
1335                    }
1336                }
1337
1338                tc.typecheck(body, &ctx)
1339            }
1340            ArrangeBy { input, keys } => {
1341                let t_in = tc.typecheck(input, ctx)?;
1342
1343                for key in keys {
1344                    for k in key {
1345                        let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1346                    }
1347                }
1348
1349                Ok(t_in)
1350            }
1351        })
1352    }
1353
1354    /// Traverses a term to collect the types of given ids.
1355    ///
1356    /// LetRec doesn't have type info stored in it. Until we change the MIR to track that information explicitly, we have to rebuild it from looking at the term.
1357    fn collect_recursive_variable_types<'a>(
1358        &self,
1359        expr: &'a MirRelationExpr,
1360        ids: &[LocalId],
1361        ctx: &mut Context,
1362    ) -> Result<(), TypeError<'a>> {
1363        use MirRelationExpr::*;
1364
1365        self.checked_recur(|tc| {
1366            match expr {
1367                Get {
1368                    id: Id::Local(id),
1369                    typ,
1370                    ..
1371                } => {
1372                    if !ids.contains(id) {
1373                        return Ok(());
1374                    }
1375
1376                    let id = Id::Local(id.clone());
1377                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1378                        let typ = typ.column_types.iter().cloned().collect_vec();
1379
1380                        if ctx_typ.len() != typ.len() {
1381                            let diffs = relation_subtype_difference(&typ, ctx_typ);
1382
1383                            return Err(TypeError::MismatchColumns {
1384                                source: expr,
1385                                got: typ,
1386                                expected: ctx_typ.clone(),
1387                                diffs,
1388                                message: "environment and type annotation did not match"
1389                                    .to_string(),
1390                            });
1391                        }
1392
1393                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1394                            let diffs = column_union(base_col, &input_col);
1395                            if !diffs.is_empty() {
1396                                return Err(TypeError::MismatchColumn {
1397                                    source: expr,
1398                                    got: input_col,
1399                                    expected: base_col.clone(),
1400                                    diffs,
1401                                    message:
1402                                        "couldn't compute union of column types in Get and context"
1403                                            .to_string(),
1404                                });
1405                            }
1406                        }
1407                    } else {
1408                        ctx.insert(id, typ.column_types.iter().cloned().collect_vec());
1409                    }
1410                }
1411                Get {
1412                    id: Id::Global(..), ..
1413                }
1414                | Constant { .. } => (),
1415                Let { id, value, body } => {
1416                    tc.collect_recursive_variable_types(value, ids, ctx)?;
1417
1418                    // we've shadowed the id
1419                    if ids.contains(id) {
1420                        return Err(TypeError::Shadowing {
1421                            source: expr,
1422                            id: Id::Local(*id),
1423                        });
1424                    }
1425
1426                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1427                }
1428                LetRec {
1429                    ids: inner_ids,
1430                    values,
1431                    body,
1432                    limits: _,
1433                } => {
1434                    for inner_id in inner_ids {
1435                        if ids.contains(inner_id) {
1436                            return Err(TypeError::Shadowing {
1437                                source: expr,
1438                                id: Id::Local(*inner_id),
1439                            });
1440                        }
1441                    }
1442
1443                    for value in values {
1444                        tc.collect_recursive_variable_types(value, ids, ctx)?;
1445                    }
1446
1447                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1448                }
1449                Project { input, .. }
1450                | Map { input, .. }
1451                | FlatMap { input, .. }
1452                | Filter { input, .. }
1453                | Reduce { input, .. }
1454                | TopK { input, .. }
1455                | Negate { input }
1456                | Threshold { input }
1457                | ArrangeBy { input, .. } => {
1458                    tc.collect_recursive_variable_types(input, ids, ctx)?;
1459                }
1460                Join { inputs, .. } => {
1461                    for input in inputs {
1462                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1463                    }
1464                }
1465                Union { base, inputs } => {
1466                    tc.collect_recursive_variable_types(base, ids, ctx)?;
1467
1468                    for input in inputs {
1469                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1470                    }
1471                }
1472            }
1473
1474            Ok(())
1475        })
1476    }
1477
1478    fn typecheck_scalar<'a>(
1479        &self,
1480        expr: &'a MirScalarExpr,
1481        source: &'a MirRelationExpr,
1482        column_types: &[ReprColumnType],
1483    ) -> Result<ReprColumnType, TypeError<'a>> {
1484        use MirScalarExpr::*;
1485
1486        self.checked_recur(|tc| match expr {
1487            Column(i, _) => match column_types.get(*i) {
1488                Some(ty) => Ok(ty.clone()),
1489                None => Err(TypeError::NoSuchColumn {
1490                    source,
1491                    expr,
1492                    col: *i,
1493                }),
1494            },
1495            Literal(row, typ) => {
1496                let typ = typ.clone();
1497                if let Ok(row) = row {
1498                    let datums = row.unpack();
1499
1500                    row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1501                }
1502
1503                Ok(typ)
1504            }
1505            CallUnmaterializable(func) => Ok(func.output_type()),
1506            CallUnary { expr, func } => {
1507                let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1508                let typ_out = func.output_type(typ_in);
1509                Ok(typ_out)
1510            }
1511            CallBinary { expr1, expr2, func } => {
1512                let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1513                let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1514                let typ_out = func.output_type(&[typ_in1, typ_in2]);
1515                Ok(typ_out)
1516            }
1517            CallVariadic { exprs, func } => Ok(func.output_type(
1518                exprs
1519                    .iter()
1520                    .map(|e| tc.typecheck_scalar(e, source, column_types))
1521                    .collect::<Result<Vec<_>, TypeError>>()?,
1522            )),
1523            If { cond, then, els } => {
1524                let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1525
1526                // condition must be boolean
1527                // ignoring nullability: null is treated as false
1528                // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1529                if cond_type.scalar_type != ReprScalarType::Bool {
1530                    let sub = cond_type.scalar_type.clone();
1531
1532                    return Err(TypeError::MismatchColumn {
1533                        source,
1534                        got: cond_type,
1535                        expected: ReprColumnType {
1536                            scalar_type: ReprScalarType::Bool,
1537                            nullable: true,
1538                        },
1539                        diffs: vec![ReprColumnTypeDifference::NotSubtype {
1540                            sub,
1541                            sup: ReprScalarType::Bool,
1542                        }],
1543                        message: "expected boolean condition".to_string(),
1544                    });
1545                }
1546
1547                let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1548                let else_type = tc.typecheck_scalar(els, source, column_types)?;
1549
1550                let diffs = column_union(&mut then_type, &else_type);
1551                if !diffs.is_empty() {
1552                    return Err(TypeError::MismatchColumn {
1553                        source,
1554                        got: then_type,
1555                        expected: else_type,
1556                        diffs,
1557                        message: "couldn't compute union of column types for If".to_string(),
1558                    });
1559                }
1560
1561                Ok(then_type)
1562            }
1563        })
1564    }
1565
1566    /// Typecheck an `AggregateExpr`
1567    pub fn typecheck_aggregate<'a>(
1568        &self,
1569        expr: &'a AggregateExpr,
1570        source: &'a MirRelationExpr,
1571        column_types: &[ReprColumnType],
1572    ) -> Result<ReprColumnType, TypeError<'a>> {
1573        self.checked_recur(|tc| {
1574            let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1575
1576            // TODO check that t_in is actually acceptable for `func`
1577
1578            Ok(expr.func.output_type(t_in))
1579        })
1580    }
1581}
1582
/// Reports a typechecking problem at a caller-chosen severity.
///
/// `type_error!(severity, fmt, args...)` takes a `bool` severity followed by
/// `format!`-style arguments, and does exactly one of the following:
///
/// * `severity == true`: the message is passed to `soft_panic_or_log!`, i.e. a failure in
///   CI and a logged error in production (visible in Sentry);
/// * `severity == false`: the message is only logged at `debug` level via `tracing`.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        if $severity {
            soft_panic_or_log!($($arg)+);
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1595
1596impl crate::Transform for Typecheck {
1597    fn name(&self) -> &'static str {
1598        "Typecheck"
1599    }
1600
1601    fn actually_perform_transform(
1602        &self,
1603        relation: &mut MirRelationExpr,
1604        transform_ctx: &mut crate::TransformCtx,
1605    ) -> Result<(), crate::TransformError> {
1606        let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1607
1608        let expected = transform_ctx
1609            .global_id
1610            .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1611
1612        if let Some(id) = transform_ctx.global_id {
1613            if self.disallow_new_globals
1614                && expected.is_none()
1615                && transform_ctx.global_id.is_some()
1616                && !id.is_transient()
1617            {
1618                type_error!(
1619                    false, // not severe
1620                    "type warning: new non-transient global id {id}\n{}",
1621                    relation.pretty()
1622                );
1623            }
1624        }
1625
1626        let got = self.typecheck(relation, &typecheck_ctx);
1627
1628        let humanizer = mz_repr::explain::DummyHumanizer;
1629
1630        match (got, expected) {
1631            (Ok(got), Some(expected)) => {
1632                let id = transform_ctx.global_id.unwrap();
1633
1634                // contravariant: global types can be updated
1635                let diffs = relation_subtype_difference(expected, &got);
1636                if !diffs.is_empty() {
1637                    // SEVERE only if got and expected have true differences, not just nullability
1638                    let severity = diffs
1639                        .iter()
1640                        .any(|diff| diff.clone().ignore_nullability().is_some());
1641
1642                    let err = TypeError::MismatchColumns {
1643                        source: relation,
1644                        got,
1645                        expected: expected.clone(),
1646                        diffs,
1647                        message: format!(
1648                            "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1649                        ),
1650                    };
1651
1652                    type_error!(severity, "type error in known global id {id}:\n{err}");
1653                }
1654            }
1655            (Ok(got), None) => {
1656                if let Some(id) = transform_ctx.global_id {
1657                    typecheck_ctx.insert(Id::Global(id), got);
1658                }
1659            }
1660            (Err(err), _) => {
1661                let (expected, binding) = match expected {
1662                    Some(expected) => {
1663                        let id = transform_ctx.global_id.unwrap();
1664                        (
1665                            format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1666                            format!("known global id {id}"),
1667                        )
1668                    }
1669                    None => ("".to_string(), "transient query".to_string()),
1670                };
1671
1672                type_error!(
1673                    true, // SEVERE: the transformed code is inconsistent
1674                    "type error in {binding}:\n{err}\n{expected}{}",
1675                    relation.pretty()
1676                );
1677            }
1678        }
1679
1680        Ok(())
1681    }
1682}
1683
1684/// Prints a type prettily with a given `ExprHumanizer`
1685pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1686where
1687    H: ExprHumanizer,
1688{
1689    let mut s = String::with_capacity(2 + 3 * cols.len());
1690
1691    s.push('(');
1692
1693    let mut it = cols.iter().peekable();
1694    while let Some(col) = it.next() {
1695        s.push_str(&humanizer.humanize_column_type(col));
1696
1697        if it.peek().is_some() {
1698            s.push_str(", ");
1699        }
1700    }
1701
1702    s.push(')');
1703
1704    s
1705}
1706
1707impl ReprRelationTypeDifference {
1708    /// Pretty prints a type difference
1709    ///
1710    /// Always indents two spaces
1711    pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1712    where
1713        H: ExprHumanizer,
1714    {
1715        use ReprRelationTypeDifference::*;
1716        match self {
1717            Length { len_sub, len_sup } => {
1718                writeln!(
1719                    f,
1720                    "  number of columns do not match ({len_sub} != {len_sup})"
1721                )
1722            }
1723            Column { col, diff } => {
1724                writeln!(f, "  column {col} differs:")?;
1725                diff.humanize(4, h, f)
1726            }
1727        }
1728    }
1729}
1730
1731impl ReprColumnTypeDifference {
1732    /// Pretty prints a type difference at a given indentation level
1733    pub fn humanize<H>(
1734        &self,
1735        indent: usize,
1736        h: &H,
1737        f: &mut std::fmt::Formatter<'_>,
1738    ) -> std::fmt::Result
1739    where
1740        H: ExprHumanizer,
1741    {
1742        use ReprColumnTypeDifference::*;
1743
1744        // indent
1745        write!(f, "{:indent$}", "")?;
1746
1747        match self {
1748            NotSubtype { sub, sup } => {
1749                let sub = h.humanize_scalar_type(sub);
1750                let sup = h.humanize_scalar_type(sup);
1751
1752                writeln!(f, "{sub} is a not a subtype of {sup}")
1753            }
1754            Nullability { sub, sup } => {
1755                let sub = h.humanize_column_type(sub);
1756                let sup = h.humanize_column_type(sup);
1757
1758                writeln!(f, "{sub} is nullable but {sup} is not")
1759            }
1760            ElementType { ctor, element_type } => {
1761                writeln!(f, "{ctor} element types differ:")?;
1762
1763                element_type.humanize(indent + 2, h, f)
1764            }
1765            RecordMissingFields { missing } => {
1766                write!(f, "missing column fields:")?;
1767                for col in missing {
1768                    write!(f, " {col}")?;
1769                }
1770                f.write_char('\n')
1771            }
1772            RecordFields { fields } => {
1773                writeln!(f, "{} record fields differ:", fields.len())?;
1774
1775                for (i, diff) in fields.iter().enumerate() {
1776                    writeln!(f, "{:indent$}  field {i}:", "")?;
1777                    diff.humanize(indent + 4, h, f)?;
1778                }
1779                Ok(())
1780            }
1781        }
1782    }
1783}
1784
1785impl DatumTypeDifference {
1786    /// Pretty prints a type difference at a given indentation level
1787    pub fn humanize<H>(
1788        &self,
1789        indent: usize,
1790        h: &H,
1791        f: &mut std::fmt::Formatter<'_>,
1792    ) -> std::fmt::Result
1793    where
1794        H: ExprHumanizer,
1795    {
1796        // indent
1797        write!(f, "{:indent$}", "")?;
1798
1799        match self {
1800            DatumTypeDifference::Null { expected } => {
1801                let expected = h.humanize_scalar_type(expected);
1802                writeln!(
1803                    f,
1804                    "unexpected null, expected representation type {expected}"
1805                )?
1806            }
1807            DatumTypeDifference::Mismatch {
1808                got_debug,
1809                expected,
1810            } => {
1811                let expected = h.humanize_scalar_type(expected);
1812                // NB `got_debug` will be redacted as appropriate
1813                writeln!(
1814                    f,
1815                    "got datum {got_debug}, expected representation type {expected}"
1816                )?;
1817            }
1818            DatumTypeDifference::MismatchDimensions {
1819                ctor,
1820                got,
1821                expected,
1822            } => {
1823                writeln!(
1824                    f,
1825                    "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1826                )?;
1827            }
1828            DatumTypeDifference::ElementType { ctor, element_type } => {
1829                writeln!(f, "{ctor} element types differ:")?;
1830                element_type.humanize(indent + 4, h, f)?;
1831            }
1832        }
1833
1834        Ok(())
1835    }
1836}
1837
/// Wrapper struct for a `Display` instance for `TypeError`s with a given `ExprHumanizer`
///
/// Formatting a value of this type renders `err`, using `humanizer` to render the
/// types that appear in it.
// NOTE(review): presumably no `Debug` impl because `H` is not required to be `Debug`.
#[allow(missing_debug_implementations)]
pub struct TypeErrorHumanizer<'a, 'b, H>
where
    H: ExprHumanizer,
{
    /// The error to render.
    err: &'a TypeError<'a>,
    /// The humanizer used to render types mentioned in `err`.
    humanizer: &'b H,
}
1847
1848impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1849where
1850    H: ExprHumanizer,
1851{
1852    /// Create a `Display`-shim struct for a given `TypeError`/`ExprHumanizer` pair
1853    pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1854        Self { err, humanizer }
1855    }
1856}
1857
1858impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1859where
1860    H: ExprHumanizer,
1861{
1862    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1863        self.err.humanize(self.humanizer, f)
1864    }
1865}
1866
1867impl<'a> std::fmt::Display for TypeError<'a> {
1868    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1869        TypeErrorHumanizer {
1870            err: self,
1871            humanizer: &DummyHumanizer,
1872        }
1873        .fmt(f)
1874    }
1875}
1876
1877impl<'a> TypeError<'a> {
1878    /// The source of the type error
1879    pub fn source(&self) -> Option<&'a MirRelationExpr> {
1880        use TypeError::*;
1881        match self {
1882            Unbound { source, .. }
1883            | NoSuchColumn { source, .. }
1884            | MismatchColumn { source, .. }
1885            | MismatchColumns { source, .. }
1886            | BadConstantRowLen { source, .. }
1887            | BadConstantRow { source, .. }
1888            | BadProject { source, .. }
1889            | BadJoinEquivalence { source, .. }
1890            | BadTopKGroupKey { source, .. }
1891            | BadTopKOrdering { source, .. }
1892            | BadLetRecBindings { source }
1893            | Shadowing { source, .. }
1894            | DisallowedDummy { source, .. } => Some(source),
1895            Recursion { .. } => None,
1896        }
1897    }
1898
1899    fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1900    where
1901        H: ExprHumanizer,
1902    {
1903        if let Some(source) = self.source() {
1904            writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1905        }
1906
1907        use TypeError::*;
1908        match self {
1909            Unbound { source: _, id, typ } => {
1910                let typ = columns_pretty(&typ.column_types, humanizer);
1911                writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1912            }
1913            NoSuchColumn {
1914                source: _,
1915                expr,
1916                col,
1917            } => writeln!(f, "{expr} references non-existent column {col}")?,
1918            MismatchColumn {
1919                source: _,
1920                got,
1921                expected,
1922                diffs,
1923                message,
1924            } => {
1925                let got = humanizer.humanize_column_type(got);
1926                let expected = humanizer.humanize_column_type(expected);
1927                writeln!(
1928                    f,
1929                    "mismatched column types: {message}\n      got {got}\nexpected {expected}"
1930                )?;
1931
1932                for diff in diffs {
1933                    diff.humanize(2, humanizer, f)?;
1934                }
1935            }
1936            MismatchColumns {
1937                source: _,
1938                got,
1939                expected,
1940                diffs,
1941                message,
1942            } => {
1943                let got = columns_pretty(got, humanizer);
1944                let expected = columns_pretty(expected, humanizer);
1945
1946                writeln!(
1947                    f,
1948                    "mismatched relation types: {message}\n      got {got}\nexpected {expected}"
1949                )?;
1950
1951                for diff in diffs {
1952                    diff.humanize(humanizer, f)?;
1953                }
1954            }
1955            BadConstantRowLen {
1956                source: _,
1957                got,
1958                expected,
1959            } => {
1960                let expected = columns_pretty(expected, humanizer);
1961                writeln!(
1962                    f,
1963                    "bad constant row\n      row has length {got}\nexpected row of type {expected}"
1964                )?
1965            }
1966            BadConstantRow {
1967                source: _,
1968                mismatches,
1969                expected,
1970            } => {
1971                let expected = columns_pretty(expected, humanizer);
1972
1973                let num_mismatches = mismatches.len();
1974                let plural = if num_mismatches == 1 { "" } else { "es" };
1975                writeln!(
1976                    f,
1977                    "bad constant row\n      got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1978                )?;
1979
1980                if num_mismatches > 0 {
1981                    writeln!(f, "")?;
1982                    for (col, diff) in mismatches.iter() {
1983                        writeln!(f, "      column #{col}:")?;
1984                        diff.humanize(8, humanizer, f)?;
1985                    }
1986                }
1987            }
1988            BadProject {
1989                source: _,
1990                got,
1991                input_type,
1992            } => {
1993                let input_type = columns_pretty(input_type, humanizer);
1994
1995                writeln!(
1996                    f,
1997                    "projection of non-existant columns {got:?} from type {input_type}"
1998                )?
1999            }
2000            BadJoinEquivalence {
2001                source: _,
2002                got,
2003                message,
2004            } => {
2005                let got = columns_pretty(got, humanizer);
2006
2007                writeln!(f, "bad join equivalence {got}: {message}")?
2008            }
2009            BadTopKGroupKey {
2010                source: _,
2011                k,
2012                input_type,
2013            } => {
2014                let input_type = columns_pretty(input_type, humanizer);
2015
2016                writeln!(
2017                    f,
2018                    "TopK group key component references invalid column {k} in columns: {input_type}"
2019                )?
2020            }
2021            BadTopKOrdering {
2022                source: _,
2023                order,
2024                input_type,
2025            } => {
2026                let col = order.column;
2027                let num_cols = input_type.len();
2028                let are = if num_cols == 1 { "is" } else { "are" };
2029                let s = if num_cols == 1 { "" } else { "s" };
2030                let input_type = columns_pretty(input_type, humanizer);
2031
2032                // TODO(cloud#8196)
2033                let mode = HumanizedExplain::new(false);
2034                let order = mode.expr(order, None);
2035
2036                writeln!(
2037                    f,
2038                    "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2039                )?
2040            }
2041            BadLetRecBindings { source: _ } => {
2042                writeln!(f, "LetRec ids and definitions don't line up")?
2043            }
2044            Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2045            DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2046            Recursion { error } => writeln!(f, "{error}")?,
2047        }
2048
2049        Ok(())
2050    }
2051}
2052
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{SqlColumnType, arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    /// An `Int16` datum matches a nullable `Int16` column type, but not a non-nullable
    /// `Int32` one.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    proptest! {
        // Datums generated for a column type may still contain dummies, which this test
        // rejects (their behavior differs); allow enough global rejects that proptest
        // does not abort the run.
        #![proptest_config(ProptestConfig {
            cases: 5000,
            max_global_rejects: 2500,
            ..Default::default()
        })]
        // `datum_difference_with_column_type` must succeed exactly when
        // `Datum::is_instance_of` accepts the datum, on datums generated for the type.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data(
            // Generate a column type, then a datum generated for that column type.
            (src, datum) in any::<SqlColumnType>()
                .prop_flat_map(|src| {
                    let datum = arb_datum_for_column(src.clone());
                    (Just(src), datum)
                })
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        // We run many cases because the data are _random_, and we want to be sure
        // that we have covered sufficient cases. Unlike the test above, no inputs are
        // rejected here: `arb_datum(false)` is expected never to produce dummy datums,
        // and the test asserts that instead.
        #![proptest_config(ProptestConfig::with_cases(10000))]
        // `datum_difference_with_column_type` must agree with `Datum::is_instance_of`
        // on a random datum paired with an unrelated random column type.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(
            src in any::<SqlColumnType>(),
            datum in arb_datum(false),
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    /// Regression test (presumably for GitHub issue 10039): a record type whose only
    /// field is a non-nullable `UInt32` must reject a datum whose single field is null.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        // The record-shaped datum is packed as a list containing a single null.
        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}