Skip to main content

mz_sql/plan/
lowering.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::func::variadic;
44use mz_expr::visit::Visit;
45use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
46use mz_ore::collections::CollectionExt;
47use mz_ore::stack::maybe_grow;
48use mz_repr::*;
49
50use crate::optimizer_metrics::OptimizerMetrics;
51use crate::plan::hir::{
52    AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
53};
54use crate::plan::{PlanError, transform_hir};
55use crate::session::vars::SystemVars;
56
57mod variadic_left;
58
59/// Maps a leveled column reference to a specific column.
60///
61/// Leveled column references are nested, so that larger levels are
62/// found early in a record and level zero is found at the end.
63///
64/// The column map only stores references for levels greater than zero,
65/// and column references at level zero simply start at the first column
66/// after all prior references.
67#[derive(Debug, Clone)]
68struct ColumnMap {
69    inner: BTreeMap<ColumnRef, usize>,
70}
71
72impl ColumnMap {
73    fn empty() -> ColumnMap {
74        Self::new(BTreeMap::new())
75    }
76
77    fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
78        ColumnMap { inner }
79    }
80
81    fn get(&self, col_ref: &ColumnRef) -> usize {
82        if col_ref.level == 0 {
83            self.inner.len() + col_ref.column
84        } else {
85            self.inner[col_ref]
86        }
87    }
88
89    fn len(&self) -> usize {
90        self.inner.len()
91    }
92
93    /// Updates references in the `ColumnMap` for use in a nested scope. The
94    /// provided `arity` must specify the arity of the current scope.
95    fn enter_scope(&self, arity: usize) -> ColumnMap {
96        // From the perspective of the nested scope, all existing column
97        // references will be one level greater.
98        let existing = self
99            .inner
100            .clone()
101            .into_iter()
102            .update(|(col, _i)| col.level += 1);
103
104        // All columns in the current scope become explicit entries in the
105        // immediate parent scope.
106        let new = (0..arity).map(|i| {
107            (
108                ColumnRef {
109                    level: 1,
110                    column: i,
111                },
112                self.len() + i,
113            )
114        });
115
116        ColumnMap::new(existing.chain(new).collect())
117    }
118}
119
/// Map with the CTEs currently in scope, keyed by the ID the CTE has in the
/// input (HIR) expression.
type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
122
/// Information needed to resolve a reference to a CTE that is in scope.
#[derive(Clone)]
struct CteDesc {
    /// The new ID assigned to the lowered version of the CTE, which may not match
    /// the ID of the input CTE.
    new_id: mz_expr::LocalId,
    /// The relation type of the lowered CTE: the columns of the outer context
    /// come first, followed by the CTE's own columns.
    relation_type: ReprRelationType,
    /// The outer relation the CTE was applied to. A reference that is lowered
    /// against this exact relation becomes a plain `Get`; any other reference
    /// is joined with its own outer relation on the shared leading columns.
    outer_relation: MirRelationExpr,
}
135
/// Feature flags that steer the HIR-to-MIR lowering.
///
/// Every flag is disabled by default (`Config::default()` yields all `false`).
#[derive(Debug, Clone, Copy, Default)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
    /// NOTE(review): not consumed in the visible part of this file; exact
    /// semantics to be confirmed against its users.
    pub enable_guard_subquery_tablefunc: bool,
    /// NOTE(review): not consumed in the visible part of this file; exact
    /// semantics to be confirmed against its users.
    pub enable_cast_elimination: bool,
    /// Forwarded to `transform_hir::try_simplify_quantified_comparisons` when
    /// an expression is lowered.
    pub enable_simplify_quantified_comparisons: bool,
}
158
159impl From<&SystemVars> for Config {
160    fn from(vars: &SystemVars) -> Self {
161        Self {
162            enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
163            enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
164            enable_guard_subquery_tablefunc: vars.enable_guard_subquery_tablefunc(),
165            enable_cast_elimination: vars.enable_cast_elimination(),
166            enable_simplify_quantified_comparisons: vars.enable_simplify_quantified_comparisons(),
167        }
168    }
169}
170
/// Context passed to the lowering. This is wired to most parts of the lowering.
///
/// It bundles the feature flags and the optional metrics handle so that both
/// travel through the lowering behind a single reference.
pub(crate) struct Context<'a> {
    /// Feature flags affecting the behavior of lowering.
    pub config: &'a Config,
    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's `None`,
    /// we simply don't write metrics.
    pub metrics: Option<&'a OptimizerMetrics>,
}
179
180impl HirRelationExpr {
181    /// Rewrite `self` into a `MirRelationExpr`.
182    /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
183    #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
184    pub fn lower<C: Into<Config>>(
185        self,
186        config: C,
187        metrics: Option<&OptimizerMetrics>,
188    ) -> Result<MirRelationExpr, PlanError> {
189        let context = Context {
190            config: &config.into(),
191            metrics,
192        };
193        let result = match self {
194            // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
195            // to ensure that the downstream optimizer can easily bypass most
196            // irrelevant optimizations (e.g. reduce folding) for this expression
197            // without having to re-learn the fact that it is just a constant,
198            // as it would if the constant were wrapped in a Let-Get pair.
199            HirRelationExpr::Constant { rows, typ } => {
200                let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
201                MirRelationExpr::Constant {
202                    rows: Ok(rows),
203                    typ: ReprRelationType::from(&typ),
204                }
205            }
206            mut other => {
207                let mut id_gen = mz_ore::id_gen::IdGen::default();
208                transform_hir::split_subquery_predicates(&mut other)?;
209                transform_hir::try_simplify_quantified_comparisons(
210                    &mut other,
211                    context.config.enable_simplify_quantified_comparisons,
212                )?;
213                transform_hir::fuse_window_functions(&mut other, &context)?;
214                MirRelationExpr::constant(vec![vec![]], ReprRelationType::new(vec![])).let_in(
215                    &mut id_gen,
216                    |id_gen, get_outer| {
217                        other.applied_to(
218                            id_gen,
219                            get_outer,
220                            &ColumnMap::empty(),
221                            &mut CteMap::new(),
222                            &context,
223                        )
224                    },
225                )?
226            }
227        };
228
229        mz_repr::explain::trace_plan(&result);
230
231        Ok(result)
232    }
233
234    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
235    ///
236    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
237    /// When `self` references columns of `get_outer`, much more work needs to occur.
238    ///
239    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
240    /// perhaps not all of them. It should be used as the basis of resolving column references,
241    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
242    /// will start, rather than any function of `col_map`.
243    ///
244    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
245    /// assignment of values to outer rows.
246    fn applied_to(
247        self,
248        id_gen: &mut mz_ore::id_gen::IdGen,
249        get_outer: MirRelationExpr,
250        col_map: &ColumnMap,
251        cte_map: &mut CteMap,
252        context: &Context,
253    ) -> Result<MirRelationExpr, PlanError> {
254        maybe_grow(|| {
255            use MirRelationExpr as SR;
256
257            use HirRelationExpr::*;
258
259            if let MirRelationExpr::Get { .. } = &get_outer {
260            } else {
261                panic!(
262                    "get_outer: expected a MirRelationExpr::Get, found\n{}",
263                    get_outer.pretty(),
264                );
265            }
266            assert_eq!(col_map.len(), get_outer.arity());
267            Ok(match self {
268                Constant { rows, typ } => {
269                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
270                    get_outer.product(SR::Constant {
271                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
272                        typ: ReprRelationType::from(&typ),
273                    })
274                }
275                Get { id, typ } => match id {
276                    mz_expr::Id::Local(local_id) => {
277                        let cte_desc = cte_map.get(&local_id).unwrap();
278                        let get_cte = SR::Get {
279                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
280                            typ: cte_desc.relation_type.clone(),
281                            access_strategy: AccessStrategy::UnknownOrLocal,
282                        };
283                        if get_outer == cte_desc.outer_relation {
284                            // If the CTE was applied to the same exact relation, we can safely
285                            // return a `Get` relation.
286                            get_cte
287                        } else {
288                            // Otherwise, the new outer relation may contain more columns from some
289                            // intermediate scope placed between the definition of the CTE and this
290                            // reference of the CTE and/or more operations applied on top of the
291                            // outer relation.
292                            //
293                            // An example of the latter is the following query:
294                            //
295                            // SELECT *
296                            // FROM x,
297                            //      LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
298                            //              SELECT (SELECT m FROM a) FROM y) b;
299                            //
300                            // When the CTE is lowered, the outer relation is `Get x`. But then,
301                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
302                            // which has the same cardinality as `Get x`.
303                            //
304                            // In any case, `get_outer` is guaranteed to contain the columns of the
305                            // outer relation the CTE was applied to at its prefix. Since, we must
306                            // return a relation containing `get_outer`'s column at the beginning,
307                            // we must build a join between `get_outer` and `get_cte` on their common
308                            // columns.
309                            let oa = get_outer.arity();
310                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
311                            let equivalences = (0..cte_outer_columns)
312                                .map(|pos| {
313                                    vec![
314                                        MirScalarExpr::column(pos),
315                                        MirScalarExpr::column(pos + oa),
316                                    ]
317                                })
318                                .collect();
319
320                            // Project out the second copy of the common between `get_outer` and
321                            // `cte_desc.outer_relation`.
322                            let projection = (0..oa)
323                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
324                                .collect_vec();
325                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
326                                .project(projection)
327                        }
328                    }
329                    mz_expr::Id::Global(_) => {
330                        // Get statements are only to external sources, and are not correlated with `get_outer`.
331                        get_outer.product(SR::Get {
332                            id,
333                            typ: ReprRelationType::from(&typ),
334                            access_strategy: AccessStrategy::UnknownOrLocal,
335                        })
336                    }
337                },
338                Let {
339                    name: _,
340                    id,
341                    value,
342                    body,
343                } => {
344                    let value =
345                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
346                    value.let_in(id_gen, |id_gen, get_value| {
347                        let (new_id, typ) = if let MirRelationExpr::Get {
348                            id: mz_expr::Id::Local(id),
349                            typ,
350                            ..
351                        } = get_value
352                        {
353                            (id, typ)
354                        } else {
355                            panic!(
356                                "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
357                                get_value.pretty(),
358                            );
359                        };
360                        // Add the information about the CTE to the map and remove it when
361                        // it goes out of scope.
362                        let old_value = cte_map.insert(
363                            id.clone(),
364                            CteDesc {
365                                new_id,
366                                relation_type: typ,
367                                outer_relation: get_outer.clone(),
368                            },
369                        );
370                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
371                        if let Some(old_value) = old_value {
372                            cte_map.insert(id, old_value);
373                        } else {
374                            cte_map.remove(&id);
375                        }
376                        body
377                    })?
378                }
379                LetRec {
380                    limit,
381                    bindings,
382                    body,
383                } => {
384                    let num_bindings = bindings.len();
385
386                    // We use the outer type with the HIR types to form MIR CTE types.
387                    let outer_column_types = get_outer.typ().column_types;
388
389                    // Rename and introduce all bindings.
390                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
391                    let mut mir_ids = Vec::with_capacity(num_bindings);
392                    for (_name, id, _value, typ) in bindings.iter() {
393                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
394                        mir_ids.push(mir_id);
395                        let shadowed = cte_map.insert(
396                            id.clone(),
397                            CteDesc {
398                                new_id: mir_id,
399                                relation_type: ReprRelationType::new(
400                                    outer_column_types
401                                        .iter()
402                                        .cloned()
403                                        .chain(typ.column_types.iter().map(ReprColumnType::from))
404                                        .collect::<Vec<_>>(),
405                                ),
406                                outer_relation: get_outer.clone(),
407                            },
408                        );
409                        shadowed_bindings.push((*id, shadowed));
410                    }
411
412                    let mut mir_values = Vec::with_capacity(num_bindings);
413                    for (_name, _id, value, _typ) in bindings.into_iter() {
414                        mir_values.push(value.applied_to(
415                            id_gen,
416                            get_outer.clone(),
417                            col_map,
418                            cte_map,
419                            context,
420                        )?);
421                    }
422
423                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
424
425                    // Remove our bindings and reinstate any shadowed bindings.
426                    for (id, shadowed) in shadowed_bindings {
427                        if let Some(shadowed) = shadowed {
428                            cte_map.insert(id, shadowed);
429                        } else {
430                            cte_map.remove(&id);
431                        }
432                    }
433
434                    MirRelationExpr::LetRec {
435                        ids: mir_ids,
436                        values: mir_values,
437                        // Copy the limit to each binding.
438                        limits: repeat(limit).take(num_bindings).collect(),
439                        body: Box::new(mir_body),
440                    }
441                }
442                Project { input, outputs } => {
443                    // Projections should be applied to the decorrelated `inner`, and to its columns,
444                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
445                    let input =
446                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
447                    let outputs = (0..get_outer.arity())
448                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
449                        .collect::<Vec<_>>();
450                    input.project(outputs)
451                }
452                Map { input, mut scalars } => {
453                    // Scalar expressions may contain correlated subqueries. We must be cautious!
454
455                    // We lower scalars in chunks, and must keep track of the
456                    // arity of the HIR fragments lowered so far.
457                    let mut lowered_arity = input.arity();
458
459                    let mut input =
460                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
461
462                    // Lower subqueries in maximally sized batches, such as no subquery in the current
463                    // batch depends on columns from the same batch.
464                    // Note that subqueries in this projection may reference columns added by this
465                    // Map operator, so we need to ensure these columns exist before lowering the
466                    // subquery.
467                    while !scalars.is_empty() {
468                        let end_idx = scalars
469                            .iter_mut()
470                            .position(|s| {
471                                let mut requires_nonexistent_column = false;
472                                #[allow(deprecated)]
473                                s.visit_columns(0, &mut |depth, col| {
474                                    if col.level == depth {
475                                        requires_nonexistent_column |= col.column >= lowered_arity
476                                    }
477                                });
478                                requires_nonexistent_column
479                            })
480                            .unwrap_or(scalars.len());
481                        assert!(
482                            end_idx > 0,
483                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
484                            lowered_arity,
485                            scalars
486                        );
487
488                        lowered_arity = lowered_arity + end_idx;
489                        let scalars = scalars.drain(0..end_idx).collect_vec();
490
491                        let old_arity = input.arity();
492                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
493                            &scalars, id_gen, col_map, cte_map, input, context,
494                        )?;
495                        input = with_subqueries;
496
497                        // We will proceed sequentially through the scalar expressions, for each transforming
498                        // the decorrelated `input` into a relation with potentially more columns capable of
499                        // addressing the needs of the scalar expression.
500                        // Having done so, we add the scalar value of interest and trim off any other newly
501                        // added columns.
502                        //
503                        // The sequential traversal is present as expressions are allowed to depend on the
504                        // values of prior expressions.
505                        let mut scalar_columns = Vec::new();
506                        for scalar in scalars {
507                            let scalar = scalar.applied_to(
508                                id_gen,
509                                col_map,
510                                cte_map,
511                                &mut input,
512                                &Some(&subquery_map),
513                                context,
514                            )?;
515                            input = input.map_one(scalar);
516                            scalar_columns.push(input.arity() - 1);
517                        }
518
519                        // Discard any new columns added by the lowering of the scalar expressions
520                        input = input.project((0..old_arity).chain(scalar_columns).collect());
521                    }
522
523                    input
524                }
525                CallTable { func, exprs } => {
526                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
527                    // allowed to refer to the results of previous expressions, and we have a simpler
528                    // implementation that appends all relevant columns first, then applies the flatmap
529                    // operator to the result, then strips off any columns introduce by subqueries.
530
531                    let mut input = get_outer;
532                    let old_arity = input.arity();
533
534                    let exprs = exprs
535                        .into_iter()
536                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
537                        .collect::<Result<Vec<_>, _>>()?;
538
539                    let new_arity = input.arity();
540                    let output_arity = func.output_arity();
541                    input = input.flat_map(func, exprs);
542                    if old_arity != new_arity {
543                        // this means we added some columns to handle subqueries, and now we need to get rid of them
544                        input = input.project(
545                            (0..old_arity)
546                                .chain(new_arity..new_arity + output_arity)
547                                .collect(),
548                        );
549                    }
550                    input
551                }
552                Filter { input, predicates } => {
553                    // Filter expressions may contain correlated subqueries.
554                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
555                    // then filter the results, then strip off any columns that were added for this purpose.
556                    let mut input =
557                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
558                    for predicate in predicates {
559                        let old_arity = input.arity();
560                        let predicate = predicate
561                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
562                        let new_arity = input.arity();
563                        input = input.filter(vec![predicate]);
564                        if old_arity != new_arity {
565                            // this means we added some columns to handle subqueries, and now we need to get rid of them
566                            input = input.project((0..old_arity).collect());
567                        }
568                    }
569                    input
570                }
571                Join {
572                    left,
573                    right,
574                    on,
575                    kind,
576                } if right.is_correlated() => {
577                    // A correlated join is a join in which the right expression has
578                    // access to the columns in the left expression. It turns out
579                    // this is *exactly* our branch operator, plus some additional
580                    // null handling in the case of left joins. (Right and full
581                    // lateral joins are not permitted.)
582                    //
583                    // As with normal joins, the `on` predicate may be correlated,
584                    // and we treat it as a filter that follows the branch.
585
586                    assert!(kind.can_be_correlated());
587
588                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
589                    left.let_in(id_gen, |id_gen, get_left| {
590                        let apply_requires_distinct_outer = false;
591                        let mut join = branch(
592                            id_gen,
593                            get_left.clone(),
594                            col_map,
595                            cte_map,
596                            *right,
597                            apply_requires_distinct_outer,
598                            context,
599                            |id_gen, right, get_left, col_map, cte_map, context| {
600                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
601                            },
602                        )?;
603
604                        // Plan the `on` predicate.
605                        let old_arity = join.arity();
606                        let on =
607                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
608                        join = join.filter(vec![on]);
609                        let new_arity = join.arity();
610                        if old_arity != new_arity {
611                            // This means we added some columns to handle
612                            // subqueries, and now we need to get rid of them.
613                            join = join.project((0..old_arity).collect());
614                        }
615
616                        // If a left join, reintroduce any rows from the left that
617                        // are missing, with nulls filled in for the right columns.
618                        if let JoinKind::LeftOuter { .. } = kind {
619                            let default = join
620                                .typ()
621                                .column_types
622                                .into_iter()
623                                .skip(get_left.arity())
624                                .map(|typ| (Datum::Null, typ.scalar_type))
625                                .collect();
626                            get_left.lookup(id_gen, join, default)
627                        } else {
628                            Ok::<_, PlanError>(join)
629                        }
630                    })?
631                }
632                Join {
633                    left,
634                    right,
635                    on,
636                    kind,
637                } => {
638                    if context.config.enable_variadic_left_join_lowering {
639                        // Attempt to extract a stack of left joins.
640                        if let JoinKind::LeftOuter = kind {
641                            let mut rights = vec![(&*right, &on)];
642                            let mut left_test = &left;
643                            while let Join {
644                                left,
645                                right,
646                                on,
647                                kind: JoinKind::LeftOuter,
648                            } = &**left_test
649                            {
650                                rights.push((&**right, on));
651                                left_test = left;
652                            }
653                            if rights.len() > 1 {
654                                // Defensively clone `cte_map` as it may be mutated.
655                                let cte_map_clone = cte_map.clone();
656                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
657                                    left_test,
658                                    rights,
659                                    id_gen,
660                                    get_outer.clone(),
661                                    col_map,
662                                    cte_map,
663                                    context,
664                                ) {
665                                    return Ok(magic);
666                                } else {
667                                    cte_map.clone_from(&cte_map_clone);
668                                }
669                            }
670                        }
671                    }
672
673                    // Both join expressions should be decorrelated, and then joined by their
674                    // leading columns to form only those pairs corresponding to the same row
675                    // of `get_outer`.
676                    //
677                    // The `on` predicate may contain correlated subqueries, and we treat it
678                    // as though it was a filter, with the caveat that we also translate outer
679                    // joins in this step. The post-filtration results need to be considered
680                    // against the records present in the left and right (decorrelated) inputs,
681                    // depending on the type of join.
682                    let oa = get_outer.arity();
683                    let left =
684                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
685                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
686                    let la = lt.len();
687                    left.let_in(id_gen, |id_gen, get_left| {
688                        let right_col_map = col_map.enter_scope(0);
689                        let right = right.applied_to(
690                            id_gen,
691                            get_outer.clone(),
692                            &right_col_map,
693                            cte_map,
694                            context,
695                        )?;
696                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
697                        let ra = rt.len();
698                        right.let_in(id_gen, |id_gen, get_right| {
699                            let mut product = SR::join(
700                                vec![get_left.clone(), get_right.clone()],
701                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
702                            )
703                            // Project away the repeated copy of get_outer's columns.
704                            .project(
705                                (0..(oa + la))
706                                    .chain((oa + la + oa)..(oa + la + oa + ra))
707                                    .collect(),
708                            );
709
710                            // Decorrelate and lower the `on` clause.
711                            let on = on.applied_to(
712                                id_gen,
713                                col_map,
714                                cte_map,
715                                &mut product,
716                                &None,
717                                context,
718                            )?;
719                            // Collect the types of all subqueries appearing in
720                            // the `on` clause. The subquery results were
721                            // appended to `product` in the `on.applied_to(...)`
722                            // call above.
723                            let on_subquery_types = product
724                                .typ()
725                                .column_types
726                                .drain(oa + la + ra..)
727                                .collect_vec();
728                            // Remember if `on` had any subqueries.
729                            let on_has_subqueries = !on_subquery_types.is_empty();
730
731                            // Attempt an efficient equijoin implementation, in which outer joins are
732                            // more efficiently rendered than in general. This can return `None` if
733                            // such a plan is not possible, for example if `on` does not describe an
734                            // equijoin between columns of `left` and `right`.
735                            if kind != JoinKind::Inner {
736                                if let Some(joined) = attempt_outer_equijoin(
737                                    get_left.clone(),
738                                    get_right.clone(),
739                                    on.clone(),
740                                    on_subquery_types,
741                                    kind.clone(),
742                                    oa,
743                                    id_gen,
744                                    context,
745                                )? {
746                                    if let Some(metrics) = context.metrics {
747                                        metrics.inc_outer_join_lowering("equi");
748                                    }
749                                    return Ok(joined);
750                                }
751                            }
752
753                            // Otherwise, perform a more general join.
754                            if let Some(metrics) = context.metrics {
755                                metrics.inc_outer_join_lowering("general");
756                            }
757                            let mut join = product.filter(vec![on]);
758                            if on_has_subqueries {
759                                // This means that `on.applied_to(...)` appended
760                                // some columns to handle subqueries, and now we
761                                // need to get rid of them.
762                                join = join.project((0..oa + la + ra).collect());
763                            }
764                            join.let_in(id_gen, |id_gen, get_join| {
765                                let mut result = get_join.clone();
766                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
767                                    kind
768                                {
769                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
770                                        id_gen,
771                                        get_join.clone(),
772                                        rt.into_iter()
773                                            .map(|typ| (Datum::Null, typ.scalar_type))
774                                            .collect(),
775                                    )?;
776                                    result = result.union(left_outer);
777                                }
778                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
779                                    let right_outer = get_right
780                                        .clone()
781                                        .anti_lookup::<PlanError>(
782                                            id_gen,
783                                            get_join
784                                                // need to swap left and right to make the anti_lookup work
785                                                .project(
786                                                    (0..oa)
787                                                        .chain((oa + la)..(oa + la + ra))
788                                                        .chain((oa)..(oa + la))
789                                                        .collect(),
790                                                ),
791                                            lt.into_iter()
792                                                .map(|typ| (Datum::Null, typ.scalar_type))
793                                                .collect(),
794                                        )?
795                                        // swap left and right back again
796                                        .project(
797                                            (0..oa)
798                                                .chain((oa + ra)..(oa + ra + la))
799                                                .chain((oa)..(oa + ra))
800                                                .collect(),
801                                        );
802                                    result = result.union(right_outer);
803                                }
804                                Ok::<MirRelationExpr, PlanError>(result)
805                            })
806                        })
807                    })?
808                }
809                Union { base, inputs } => {
810                    // Union is uncomplicated.
811                    SR::Union {
812                        base: Box::new(base.applied_to(
813                            id_gen,
814                            get_outer.clone(),
815                            col_map,
816                            cte_map,
817                            context,
818                        )?),
819                        inputs: inputs
820                            .into_iter()
821                            .map(|input| {
822                                input.applied_to(
823                                    id_gen,
824                                    get_outer.clone(),
825                                    col_map,
826                                    cte_map,
827                                    context,
828                                )
829                            })
830                            .collect::<Result<Vec<_>, _>>()?,
831                    }
832                }
833                Reduce {
834                    input,
835                    group_key,
836                    aggregates,
837                    expected_group_size,
838                } => {
839                    // Reduce may contain expressions with correlated subqueries.
840                    // In addition, here an empty reduction key signifies that we need to supply default values
841                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
842                    let mut input =
843                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
844                    let applied_group_key = (0..get_outer.arity())
845                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
846                        .collect();
847                    let applied_aggregates = aggregates
848                        .into_iter()
849                        .map(|aggregate| {
850                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
851                        })
852                        .collect::<Result<Vec<_>, _>>()?;
853                    let input_type = input.typ();
854                    let default = applied_aggregates
855                        .iter()
856                        .map(|agg| {
857                            (
858                                agg.func.default(),
859                                agg.typ(&input_type.column_types).scalar_type,
860                            )
861                        })
862                        .collect();
863                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
864                    let mut reduced =
865                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);
866
867                    // Introduce default values in the case the group key is empty.
868                    if group_key.is_empty() {
869                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
870                    }
871                    reduced
872                }
873                Distinct { input } => {
874                    // Distinct is uncomplicated.
875                    input
876                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
877                        .distinct()
878                }
879                TopK {
880                    input,
881                    group_key,
882                    order_key,
883                    limit,
884                    offset,
885                    expected_group_size,
886                } => {
887                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
888                    let mut input =
889                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
890                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
891                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
892                        .collect();
893                    let applied_order_key = order_key
894                        .iter()
895                        .map(|column_order| ColumnOrder {
896                            column: column_order.column + get_outer.arity(),
897                            desc: column_order.desc,
898                            nulls_last: column_order.nulls_last,
899                        })
900                        .collect();
901
902                    let old_arity = input.arity();
903
904                    // Lower `limit`, which may introduce new columns if it is a correlated subquery.
905                    let mut limit_mir = None;
906                    if let Some(limit) = limit {
907                        limit_mir = Some(
908                            limit
909                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
910                        );
911                    }
912
913                    let new_arity = input.arity();
914                    // Extend the key to contain any new columns.
915                    applied_group_key.extend(old_arity..new_arity);
916
917                    let offset = offset
918                        .try_into_literal_int64()
919                        .expect("Should be a Literal by this time")
920                        .try_into()
921                        .expect("Should have checked non-negativity of OFFSET clause already");
922                    let mut result = input.top_k(
923                        applied_group_key,
924                        applied_order_key,
925                        limit_mir,
926                        offset,
927                        expected_group_size,
928                    );
929
930                    // If new columns were added for `limit` we must remove them.
931                    if old_arity != new_arity {
932                        result = result.project((0..old_arity).collect());
933                    }
934
935                    result
936                }
937                Negate { input } => {
938                    // Negate is uncomplicated.
939                    input
940                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
941                        .negate()
942                }
943                Threshold { input } => {
944                    // Threshold is uncomplicated.
945                    input
946                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
947                        .threshold()
948                }
949            })
950        })
951    }
952}
953
954impl HirScalarExpr {
955    /// Rewrite `self` into a `MirScalarExpr` which can be applied to the modified `inner`.
956    ///
957    /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
958    /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
959    /// The most complicated logic is for the scalar expressions that involve subqueries, each of which is
960    /// documented in more detail closer to its logic.
961    ///
962    /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
963    /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
964    /// each of these references can be found.
965    fn applied_to(
966        self,
967        id_gen: &mut mz_ore::id_gen::IdGen,
968        col_map: &ColumnMap,
969        cte_map: &mut CteMap,
970        inner: &mut MirRelationExpr,
971        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
972        context: &Context,
973    ) -> Result<MirScalarExpr, PlanError> {
974        maybe_grow(|| {
975            use MirScalarExpr as SS;
976
977            use HirScalarExpr::*;
978
979            if let Some(subquery_map) = subquery_map {
980                if let Some(col) = subquery_map.get(&self) {
981                    return Ok(SS::column(*col));
982                }
983            }
984
985            Ok::<MirScalarExpr, PlanError>(match self {
986                Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
987                Literal(row, typ, _name) => SS::Literal(Ok(row), ReprColumnType::from(&typ)),
988                Parameter(_, _name) => {
989                    panic!("cannot decorrelate expression with unbound parameters")
990                }
991                CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
992                CallUnary {
993                    func,
994                    expr,
995                    name: _,
996                } => {
997                    let inner =
998                        expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
999                    if context.config.enable_cast_elimination && func.is_eliminable_cast() {
1000                        inner
1001                    } else {
1002                        SS::CallUnary {
1003                            func,
1004                            expr: Box::new(inner),
1005                        }
1006                    }
1007                }
1008                CallBinary {
1009                    func,
1010                    expr1,
1011                    expr2,
1012                    name: _,
1013                } => SS::CallBinary {
1014                    func,
1015                    expr1: Box::new(expr1.applied_to(
1016                        id_gen,
1017                        col_map,
1018                        cte_map,
1019                        inner,
1020                        subquery_map,
1021                        context,
1022                    )?),
1023                    expr2: Box::new(expr2.applied_to(
1024                        id_gen,
1025                        col_map,
1026                        cte_map,
1027                        inner,
1028                        subquery_map,
1029                        context,
1030                    )?),
1031                },
1032                CallVariadic {
1033                    func,
1034                    exprs,
1035                    name: _,
1036                } => SS::call_variadic(
1037                    func,
1038                    exprs
1039                        .into_iter()
1040                        .map(|expr| {
1041                            expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
1042                        })
1043                        .collect::<Result<_, _>>()?,
1044                ),
1045                If {
1046                    cond,
1047                    then,
1048                    els,
1049                    name,
1050                } => {
1051                    // The `If` case is complicated by the fact that we do not want to
1052                    // apply the `then` or `else` logic to tuples that respectively do
1053                    // not or do pass the `cond` test. Our strategy is to independently
1054                    // decorrelate the `then` and `else` logic, and apply each to tuples
1055                    // that respectively pass and do not pass the `cond` logic (which is
1056                    // executed, and so decorrelated, for all tuples).
1057                    //
1058                    // Informally, we turn the `if` statement into:
1059                    //
1060                    //   let then_case = inner.filter(cond).map(then);
1061                    //   let else_case = inner.filter(!cond).map(else);
1062                    //   return then_case.concat(else_case);
1063                    //
1064                    // We only require this if either expression would result in any
1065                    // computation beyond the expr itself, which we will interpret as
1066                    // "introduces additional columns". In the absence of correlation,
1067                    // we should just retain a `MirScalarExpr::If` expression; the inverse
1068                    // transformation as above is complicated to recover after the fact,
1069                    // and we would benefit from not introducing the complexity.
1070
1071                    let inner_arity = inner.arity();
1072                    let cond_expr =
1073                        cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1074
1075                    // Defensive copies, in case we mangle these in decorrelation.
1076                    let inner_clone = inner.clone();
1077                    let then_clone = then.clone();
1078                    let else_clone = els.clone();
1079
1080                    let cond_arity = inner.arity();
1081                    let then_expr =
1082                        then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1083                    let else_expr =
1084                        els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1085
1086                    if cond_arity == inner.arity() {
1087                        // If no additional columns were added, we simply return the
1088                        // `If` variant with the updated expressions.
1089                        SS::If {
1090                            cond: Box::new(cond_expr),
1091                            then: Box::new(then_expr),
1092                            els: Box::new(else_expr),
1093                        }
1094                    } else {
1095                        // If columns were added, we need a more careful approach, as
1096                        // described above. First, we need to de-correlate each of
1097                        // the two expressions independently, and apply their cases
1098                        // as `MirRelationExpr::Map` operations.
1099
1100                        *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
1101                            // Restrict to records satisfying `cond_expr` and apply `then` as a map.
1102                            let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
1103                            let then_expr = then_clone.applied_to(
1104                                id_gen,
1105                                col_map,
1106                                cte_map,
1107                                &mut then_inner,
1108                                subquery_map,
1109                                context,
1110                            )?;
1111                            let then_arity = then_inner.arity();
1112                            then_inner = then_inner
1113                                .map_one(then_expr)
1114                                .project((0..inner_arity).chain(Some(then_arity)).collect());
1115
1116                            // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
1117                            let mut else_inner = get_inner.filter(vec![SS::call_variadic(
1118                                variadic::Or,
1119                                vec![
1120                                    cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
1121                                    cond_expr.clone().call_is_null(),
1122                                ],
1123                            )]);
1124                            let else_expr = else_clone.applied_to(
1125                                id_gen,
1126                                col_map,
1127                                cte_map,
1128                                &mut else_inner,
1129                                subquery_map,
1130                                context,
1131                            )?;
1132                            let else_arity = else_inner.arity();
1133                            else_inner = else_inner
1134                                .map_one(else_expr)
1135                                .project((0..inner_arity).chain(Some(else_arity)).collect());
1136
1137                            // concatenate the two results.
1138                            Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
1139                        })?;
1140
1141                        SS::Column(inner_arity, name)
1142                    }
1143                }
1144
1145                // Subqueries!
1146                // These are surprisingly subtle. Things to be careful of:
1147
1148                // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
1149                // * change the row counts of the outer query
1150                // * accidentally compute its own value using the row counts of the outer query
1151                // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
1152                // query and then join the answers back on to the original rows of the outer query.
1153
1154                // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
1155                // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
1156                Exists(expr, name) => {
1157                    let apply_requires_distinct_outer = true;
1158                    *inner = apply_existential_subquery(
1159                        id_gen,
1160                        inner.take_dangerous(),
1161                        col_map,
1162                        cte_map,
1163                        *expr,
1164                        apply_requires_distinct_outer,
1165                        context,
1166                    )?;
1167                    SS::Column(inner.arity() - 1, name)
1168                }
1169
1170                Select(expr, name) => {
1171                    let apply_requires_distinct_outer = true;
1172                    *inner = apply_scalar_subquery(
1173                        id_gen,
1174                        inner.take_dangerous(),
1175                        col_map,
1176                        cte_map,
1177                        *expr,
1178                        apply_requires_distinct_outer,
1179                        context,
1180                    )?;
1181                    SS::Column(inner.arity() - 1, name)
1182                }
1183                Windowing(expr, _name) => {
1184                    let partition_by = expr.partition_by;
1185                    let order_by = expr.order_by;
1186
1187                    // argument lowering for scalar window functions
1188                    // (We need to specify the & _ in the arguments because of this problem:
1189                    // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
1190                    let scalar_lower_args =
1191                        |_id_gen: &mut _,
1192                         _col_map: &_,
1193                         _cte_map: &mut _,
1194                         _get_inner: &mut _,
1195                         _subquery_map: &Option<&_>,
1196                         order_by_mir: Vec<MirScalarExpr>,
1197                         original_row_record,
1198                         original_row_record_type: SqlScalarType| {
1199                            let agg_input = MirScalarExpr::call_variadic(
1200                                variadic::ListCreate {
1201                                    elem_type: original_row_record_type.clone(),
1202                                },
1203                                vec![original_row_record],
1204                            );
1205                            let mut agg_input = vec![agg_input];
1206                            agg_input.extend(order_by_mir.clone());
1207                            let agg_input = MirScalarExpr::call_variadic(
1208                                variadic::RecordCreate {
1209                                    field_names: (0..agg_input.len())
1210                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1211                                        .collect_vec(),
1212                                },
1213                                agg_input,
1214                            );
1215                            let list_type = SqlScalarType::List {
1216                                element_type: Box::new(original_row_record_type),
1217                                custom_id: None,
1218                            };
1219                            let agg_input_type = SqlScalarType::Record {
1220                                fields: std::iter::once(&list_type)
1221                                    .map(|t| {
1222                                        (
1223                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1224                                            t.clone().nullable(false),
1225                                        )
1226                                    })
1227                                    .collect(),
1228                                custom_id: None,
1229                            }
1230                            .nullable(false);
1231
1232                            Ok((agg_input, agg_input_type))
1233                        };
1234
1235                    // argument lowering for value window functions and aggregate window functions
1236                    let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
1237                        |id_gen: &mut _,
1238                         col_map: &_,
1239                         cte_map: &mut _,
1240                         get_inner: &mut _,
1241                         subquery_map: &Option<&_>,
1242                         order_by_mir: Vec<MirScalarExpr>,
1243                         original_row_record,
1244                         original_row_record_type| {
1245                            // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]
1246
1247                            // Compute the encoded args for all rows
1248                            let mir_encoded_args = hir_encoded_args.applied_to(
1249                                id_gen,
1250                                col_map,
1251                                cte_map,
1252                                get_inner,
1253                                subquery_map,
1254                                context,
1255                            )?;
1256                            let mir_encoded_args_type = mir_encoded_args
1257                                .sql_typ(&get_inner.sql_typ().column_types)
1258                                .scalar_type;
1259
1260                            // Build a new record that has two fields:
1261                            // 1. the original row in a record
1262                            // 2. the encoded args (which can be either a single value, or a record
1263                            //    if the window function has multiple arguments, such as `lag`)
1264                            let fn_input_record_fields: Box<[_]> =
1265                                [original_row_record_type, mir_encoded_args_type]
1266                                    .iter()
1267                                    .map(|t| {
1268                                        (
1269                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1270                                            t.clone().nullable(false),
1271                                        )
1272                                    })
1273                                    .collect();
1274                            let fn_input_record = MirScalarExpr::call_variadic(
1275                                variadic::RecordCreate {
1276                                    field_names: fn_input_record_fields
1277                                        .iter()
1278                                        .map(|(n, _)| n.clone())
1279                                        .collect_vec(),
1280                                },
1281                                vec![original_row_record, mir_encoded_args],
1282                            );
1283                            let fn_input_record_type = SqlScalarType::Record {
1284                                fields: fn_input_record_fields,
1285                                custom_id: None,
1286                            }
1287                            .nullable(false);
1288
1289                            // Build a new record with the record above + the ORDER BY exprs
1290                            // This follows the standard encoding of ORDER BY exprs used by aggregate functions
1291                            let mut agg_input = vec![fn_input_record];
1292                            agg_input.extend(order_by_mir.clone());
1293                            let agg_input = MirScalarExpr::call_variadic(
1294                                variadic::RecordCreate {
1295                                    field_names: (0..agg_input.len())
1296                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1297                                        .collect_vec(),
1298                                },
1299                                agg_input,
1300                            );
1301
1302                            let agg_input_type = SqlScalarType::Record {
1303                                fields: [(
1304                                    ColumnName::from(UNKNOWN_COLUMN_NAME),
1305                                    fn_input_record_type.nullable(false),
1306                                )]
1307                                .into(),
1308                                custom_id: None,
1309                            }
1310                            .nullable(false);
1311
1312                            Ok((agg_input, agg_input_type))
1313                        }
1314                    };
1315
1316                    match expr.func {
1317                        WindowExprType::Scalar(scalar_window_expr) => {
1318                            let mir_aggr_func = scalar_window_expr.into_expr();
1319                            Self::window_func_applied_to(
1320                                id_gen,
1321                                col_map,
1322                                cte_map,
1323                                inner,
1324                                subquery_map,
1325                                partition_by,
1326                                order_by,
1327                                mir_aggr_func,
1328                                scalar_lower_args,
1329                                context,
1330                            )?
1331                        }
1332                        WindowExprType::Value(value_window_expr) => {
1333                            let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();
1334
1335                            Self::window_func_applied_to(
1336                                id_gen,
1337                                col_map,
1338                                cte_map,
1339                                inner,
1340                                subquery_map,
1341                                partition_by,
1342                                order_by,
1343                                mir_aggr_func,
1344                                value_or_aggr_lower_args(hir_encoded_args),
1345                                context,
1346                            )?
1347                        }
1348                        WindowExprType::Aggregate(aggr_window_expr) => {
1349                            let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();
1350
1351                            Self::window_func_applied_to(
1352                                id_gen,
1353                                col_map,
1354                                cte_map,
1355                                inner,
1356                                subquery_map,
1357                                partition_by,
1358                                order_by,
1359                                mir_aggr_func,
1360                                value_or_aggr_lower_args(hir_encoded_args),
1361                                context,
1362                            )?
1363                        }
1364                    }
1365                }
1366            })
1367        })
1368    }
1369
1370    fn window_func_applied_to<F>(
1371        id_gen: &mut mz_ore::id_gen::IdGen,
1372        col_map: &ColumnMap,
1373        cte_map: &mut CteMap,
1374        inner: &mut MirRelationExpr,
1375        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
1376        partition_by: Vec<HirScalarExpr>,
1377        order_by: Vec<HirScalarExpr>,
1378        mir_aggr_func: AggregateFunc,
1379        lower_args: F,
1380        context: &Context,
1381    ) -> Result<MirScalarExpr, PlanError>
1382    where
1383        F: FnOnce(
1384            &mut mz_ore::id_gen::IdGen,
1385            &ColumnMap,
1386            &mut CteMap,
1387            &mut MirRelationExpr,
1388            &Option<&BTreeMap<HirScalarExpr, usize>>,
1389            Vec<MirScalarExpr>,
1390            MirScalarExpr,
1391            SqlScalarType,
1392        ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
1393    {
1394        // Example MIRs for a window function (specifically, a window aggregation):
1395        //
1396        // CREATE TABLE t7(x INT, y INT);
1397        //
1398        // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1399        //
1400        // Decorrelated Plan
1401        // Project (#3)
1402        //   Map (#2)
1403        //     Project (#3..=#5)
1404        //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
1405        //         FlatMap unnest_list(#1)
1406        //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1407        //             Map ((#0 + #1))
1408        //               CrossJoin
1409        //                 Constant
1410        //                   - ()
1411        //                 Get materialize.public.t7
1412        //
1413        // The same query after optimizations:
1414        //
1415        // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1416        //
1417        // Optimized Plan
1418        // Explained Query:
1419        //   Project (#2)
1420        //     Map (record_get[0](#1))
1421        //       FlatMap unnest_list(#0)
1422        //         Project (#1)
1423        //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1424        //             ReadStorage materialize.public.t7
1425        //
1426        // The `row(row(row(...), ...), ...)` stuff means the following:
1427        // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
1428        //   - The <arguments to window function> can be either a single value or itself a
1429        //     `row` if there are multiple arguments.
1430        //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
1431        //     ORDER BY columns.
1432        //   - The <original row> currently always captures the entire original row. This should
1433        //     improve when we make `ProjectionPushdown` smarter, see
1434        //     https://github.com/MaterializeInc/database-issues/issues/5090
1435        //
1436        // TODO:
1437        // We should probably introduce some dedicated Datum constructor functions instead of `row`
1438        // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
1439        // might even introduce dedicated Datum enum variants, so that the rendering code also
1440        // becomes more readable (and possibly slightly more performant).
1441
1442        *inner = inner
1443            .take_dangerous()
1444            .let_in(id_gen, |id_gen, mut get_inner| {
1445                let order_by_mir = order_by
1446                    .into_iter()
1447                    .map(|o| {
1448                        o.applied_to(
1449                            id_gen,
1450                            col_map,
1451                            cte_map,
1452                            &mut get_inner,
1453                            subquery_map,
1454                            context,
1455                        )
1456                    })
1457                    .collect::<Result<Vec<_>, _>>()?;
1458
1459                // Record input arity here so that any group_keys that need to mutate get_inner
1460                // don't add those columns to the aggregate input.
1461                let input_type = get_inner.sql_typ();
1462                let input_arity = input_type.arity();
1463                // The reduction that computes the window function must be keyed on the columns
1464                // from the outer context, plus the expressions in the partition key. The current
1465                // subquery will be 'executed' for every distinct row from the outer context so
1466                // by putting the outer columns in the grouping key we isolate each re-execution.
1467                let mut group_key = col_map
1468                    .inner
1469                    .iter()
1470                    .map(|(_, outer_col)| *outer_col)
1471                    .sorted()
1472                    .collect_vec();
1473                for p in partition_by {
1474                    let key = p.applied_to(
1475                        id_gen,
1476                        col_map,
1477                        cte_map,
1478                        &mut get_inner,
1479                        subquery_map,
1480                        context,
1481                    )?;
1482                    if let MirScalarExpr::Column(c, _name) = key {
1483                        group_key.push(c);
1484                    } else {
1485                        get_inner = get_inner.map_one(key);
1486                        group_key.push(get_inner.arity() - 1);
1487                    }
1488                }
1489
1490                get_inner.let_in(id_gen, |id_gen, mut get_inner| {
1491                    // Original columns of the relation
1492                    let fields: Box<_> = input_type
1493                        .column_types
1494                        .iter()
1495                        .take(input_arity)
1496                        .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
1497                        .collect();
1498
1499                    // Original row made into a record
1500                    let original_row_record = MirScalarExpr::call_variadic(
1501                        variadic::RecordCreate {
1502                            field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
1503                        },
1504                        (0..input_arity).map(MirScalarExpr::column).collect_vec(),
1505                    );
1506                    let original_row_record_type = SqlScalarType::Record {
1507                        fields,
1508                        custom_id: None,
1509                    };
1510
1511                    let (agg_input, agg_input_type) = lower_args(
1512                        id_gen,
1513                        col_map,
1514                        cte_map,
1515                        &mut get_inner,
1516                        subquery_map,
1517                        order_by_mir,
1518                        original_row_record,
1519                        original_row_record_type,
1520                    )?;
1521
1522                    let aggregate = mz_expr::AggregateExpr {
1523                        func: mir_aggr_func,
1524                        expr: agg_input,
1525                        distinct: false,
1526                    };
1527
1528                    // Actually call reduce with the window function
1529                    // The output of the aggregation function should be a list of tuples that has
1530                    // the result in the first position, and the original row in the second position
1531                    let mut reduce = get_inner
1532                        .reduce(group_key.clone(), vec![aggregate.clone()], None)
1533                        .flat_map(
1534                            mz_expr::TableFunc::UnnestList {
1535                                el_typ: aggregate
1536                                    .func
1537                                    .output_sql_type(agg_input_type)
1538                                    .scalar_type
1539                                    .unwrap_list_element_type()
1540                                    .clone(),
1541                            },
1542                            vec![MirScalarExpr::column(group_key.len())],
1543                        );
1544                    let record_col = reduce.arity() - 1;
1545
1546                    // Unpack the record output by the window function
1547                    for c in 0..input_arity {
1548                        reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1549                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
1550                            expr: Box::new(MirScalarExpr::CallUnary {
1551                                func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
1552                                expr: Box::new(MirScalarExpr::column(record_col)),
1553                            }),
1554                        });
1555                    }
1556
1557                    // Append the column with the result of the window function.
1558                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1559                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
1560                        expr: Box::new(MirScalarExpr::column(record_col)),
1561                    });
1562
1563                    let agg_col = record_col + 1 + input_arity;
1564                    Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
1565                })
1566            })?;
1567        Ok(MirScalarExpr::column(inner.arity() - 1))
1568    }
1569
1570    /// Applies the subqueries in the given list of scalar expressions to every distinct
1571    /// value of the given relation and returns a join of the given relation with all
1572    /// the subqueries found, and the mapping of scalar expressions with columns projected
1573    /// by the returned join that will hold their results.
1574    fn lower_subqueries(
1575        exprs: &[Self],
1576        id_gen: &mut mz_ore::id_gen::IdGen,
1577        col_map: &ColumnMap,
1578        cte_map: &mut CteMap,
1579        inner: MirRelationExpr,
1580        context: &Context,
1581    ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1582        let mut subquery_map = BTreeMap::new();
1583        let output = inner.let_in(id_gen, |id_gen, get_inner| {
1584            let mut subqueries = Vec::new();
1585            let distinct_inner = get_inner.clone().distinct();
1586            for expr in exprs.iter() {
1587                expr.visit_pre_post(
1588                    &mut |e| match e {
1589                        // For simplicity, subqueries within a conditional statement will be
1590                        // lowered when lowering the conditional expression.
1591                        HirScalarExpr::If { .. } => Some(vec![]),
1592                        _ => None,
1593                    },
1594                    &mut |e| match e {
1595                        HirScalarExpr::Select(expr, _name) => {
1596                            let apply_requires_distinct_outer = false;
1597                            let subquery = apply_scalar_subquery(
1598                                id_gen,
1599                                distinct_inner.clone(),
1600                                col_map,
1601                                cte_map,
1602                                (**expr).clone(),
1603                                apply_requires_distinct_outer,
1604                                context,
1605                            )
1606                            .unwrap();
1607
1608                            subqueries.push((e.clone(), subquery));
1609                        }
1610                        HirScalarExpr::Exists(expr, _name) => {
1611                            let apply_requires_distinct_outer = false;
1612                            let subquery = apply_existential_subquery(
1613                                id_gen,
1614                                distinct_inner.clone(),
1615                                col_map,
1616                                cte_map,
1617                                (**expr).clone(),
1618                                apply_requires_distinct_outer,
1619                                context,
1620                            )
1621                            .unwrap();
1622                            subqueries.push((e.clone(), subquery));
1623                        }
1624                        _ => {}
1625                    },
1626                )?;
1627            }
1628
1629            if subqueries.is_empty() {
1630                Ok::<MirRelationExpr, PlanError>(get_inner)
1631            } else {
1632                let inner_arity = get_inner.arity();
1633                let mut total_arity = inner_arity;
1634                let mut join_inputs = vec![get_inner];
1635                let mut join_input_arities = vec![inner_arity];
1636                for (expr, subquery) in subqueries.into_iter() {
1637                    // Avoid lowering duplicated subqueries
1638                    if !subquery_map.contains_key(&expr) {
1639                        let subquery_arity = subquery.arity();
1640                        assert_eq!(subquery_arity, inner_arity + 1);
1641                        join_inputs.push(subquery);
1642                        join_input_arities.push(subquery_arity);
1643                        total_arity += subquery_arity;
1644
1645                        // Column with the value of the subquery
1646                        subquery_map.insert(expr, total_arity - 1);
1647                    }
1648                }
1649                // Each subquery projects all the columns of the outer context (distinct_inner)
1650                // plus 1 column, containing the result of the subquery. Those columns must be
1651                // joined with the outer/main relation (get_inner).
1652                let input_mapper =
1653                    mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1654                let equivalences = (0..inner_arity)
1655                    .map(|col| {
1656                        join_inputs
1657                            .iter()
1658                            .enumerate()
1659                            .map(|(input, _)| {
1660                                MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1661                            })
1662                            .collect_vec()
1663                    })
1664                    .collect_vec();
1665                Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1666            }
1667        })?;
1668        Ok((output, subquery_map))
1669    }
1670
1671    /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1672    ///
1673    /// Returns an _internal_ error if the expression contains
1674    /// - a subquery
1675    /// - a column reference to an outer level
1676    /// - a parameter
1677    /// - a window function call
1678    ///
1679    /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1680    ///
1681    /// Set `enable_cast_elimination` to remove casts that are noops in MIR.
1682    pub fn lower_uncorrelated<C: Into<Config>>(
1683        self,
1684        config: C,
1685    ) -> Result<MirScalarExpr, PlanError> {
1686        let config = config.into();
1687
1688        use MirScalarExpr as SS;
1689
1690        use HirScalarExpr::*;
1691
1692        Ok(match self {
1693            Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1694            Literal(datum, typ, _name) => SS::Literal(Ok(datum), ReprColumnType::from(&typ)),
1695            CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1696            CallUnary {
1697                func,
1698                expr,
1699                name: _,
1700            } => {
1701                let inner = expr.lower_uncorrelated(config)?;
1702
1703                if config.enable_cast_elimination && func.is_eliminable_cast() {
1704                    inner
1705                } else {
1706                    SS::CallUnary {
1707                        func,
1708                        expr: Box::new(inner),
1709                    }
1710                }
1711            }
1712            CallBinary {
1713                func,
1714                expr1,
1715                expr2,
1716                name: _,
1717            } => SS::CallBinary {
1718                func,
1719                expr1: Box::new(expr1.lower_uncorrelated(config)?),
1720                expr2: Box::new(expr2.lower_uncorrelated(config)?),
1721            },
1722            CallVariadic {
1723                func,
1724                exprs,
1725                name: _,
1726            } => SS::call_variadic(
1727                func,
1728                exprs
1729                    .into_iter()
1730                    .map(|expr| expr.lower_uncorrelated(config))
1731                    .collect::<Result<_, _>>()?,
1732            ),
1733            If {
1734                cond,
1735                then,
1736                els,
1737                name: _,
1738            } => SS::If {
1739                cond: Box::new(cond.lower_uncorrelated(config)?),
1740                then: Box::new(then.lower_uncorrelated(config)?),
1741                els: Box::new(els.lower_uncorrelated(config)?),
1742            },
1743            Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1744                sql_bail!(
1745                    "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1746                    self
1747                );
1748            }
1749        })
1750    }
1751}
1752
1753/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
1754/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
1755/// will, in effect, be executed once for every distinct row in `outer`, and the
1756/// results will be joined with `outer`. Note that columns in `outer` that are
1757/// not depended upon by `inner` are thrown away before the distinct, so that we
1758/// don't perform needless computation of `inner`.
1759///
1760/// `branch` will inspect the contents of `inner` to determine whether `inner`
1761/// is not multiplicity sensitive (roughly, contains only maps, filters,
1762/// projections, and calls to table functions). If it is not multiplicity
1763/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
1764/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
1765/// operations that were not present in `inner`, then set
1766/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
1767/// that distinctifies `outer`.
1768///
1769/// The caller must supply the `apply` function that applies the rewritten
1770/// `inner` to `outer`.
1771fn branch<F>(
1772    id_gen: &mut mz_ore::id_gen::IdGen,
1773    outer: MirRelationExpr,
1774    col_map: &ColumnMap,
1775    cte_map: &mut CteMap,
1776    inner: HirRelationExpr,
1777    apply_requires_distinct_outer: bool,
1778    context: &Context,
1779    apply: F,
1780) -> Result<MirRelationExpr, PlanError>
1781where
1782    F: FnOnce(
1783        &mut mz_ore::id_gen::IdGen,
1784        HirRelationExpr,
1785        MirRelationExpr,
1786        &ColumnMap,
1787        &mut CteMap,
1788        &Context,
1789    ) -> Result<MirRelationExpr, PlanError>,
1790{
1791    // TODO: It would be nice to have a version of this code w/o optimizations,
1792    // at the least for purposes of understanding. It was difficult for one reader
1793    // to understand the required properties of `outer` and `col_map`.
1794
1795    // If the inner expression is sufficiently simple, it is safe to apply it
1796    // *directly* to outer, rather than applying it to the distinctified key
1797    // (see below).
1798    //
1799    // As an example, consider the following two queries:
1800    //
1801    //     CREATE TABLE t (a int, b int);
1802    //     SELECT a, series FROM t, generate_series(1, t.b) series;
1803    //
1804    // The "simple" path for the `SELECT` yields
1805    //
1806    //     %0 =
1807    //     | Get t
1808    //     | FlatMap generate_series(1, #1)
1809    //
1810    // while the non-simple path yields:
1811    //
1812    //    %0 =
1813    //    | Get t
1814    //
1815    //    %1 =
1816    //    | Get t
1817    //    | Distinct group=(#1)
1818    //    | FlatMap generate_series(1, #0)
1819    //
1820    //    %2 =
1821    //    | LeftJoin %1 %2 (= #1 #2)
1822    //
1823    // There is a tradeoff here: the simple plan is stateless, but the non-
1824    // simple plan may do (much) less computation if there are only a few
1825    // distinct values of `t.b`.
1826    //
1827    // We apply a very simple heuristic here and take the simple path if `inner`
1828    // contains only maps, filters, projections, and calls to table functions.
1829    // The intuition is that straightforward usage of table functions should
1830    // take the simple path, while everything else should not. (In theory we
1831    // think this transformation is valid as long as `inner` does not contain a
1832    // Reduce, Distinct, or TopK node, but it is not always an optimization in
1833    // the general case.)
1834    //
1835    // TODO(benesch): this should all be handled by a proper optimizer, but
1836    // detecting the moment of decorrelation in the optimizer right now is too
1837    // hard.
1838    let mut is_simple = true;
1839    #[allow(deprecated)]
1840    inner.visit(0, &mut |expr, _| match expr {
1841        HirRelationExpr::Constant { .. }
1842        | HirRelationExpr::Project { .. }
1843        | HirRelationExpr::Map { .. }
1844        | HirRelationExpr::Filter { .. }
1845        | HirRelationExpr::CallTable { .. } => (),
1846        _ => is_simple = false,
1847    });
1848    if is_simple && !apply_requires_distinct_outer {
1849        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
1850        return outer.let_in(id_gen, |id_gen, get_outer| {
1851            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
1852        });
1853    }
1854
1855    // The key consists of the columns from the outer expression upon which the
1856    // inner relation depends. We discover these dependencies by walking the
1857    // inner relation expression and looking for column references whose level
1858    // escapes inner.
1859    //
1860    // At the end of this process, `key` contains the decorrelated position of
1861    // each outer column, according to the passed-in `col_map`, and
1862    // `new_col_map` maps each outer column to its new ordinal position in key.
1863    let mut outer_cols = BTreeSet::new();
1864    #[allow(deprecated)]
1865    inner.visit_columns(0, &mut |depth, col| {
1866        // Test if the column reference escapes the subquery.
1867        if col.level > depth {
1868            outer_cols.insert(ColumnRef {
1869                level: col.level - depth,
1870                column: col.column,
1871            });
1872        }
1873    });
1874    // Collect all the outer columns referenced by any CTE referenced by
1875    // the inner relation.
1876    #[allow(deprecated)]
1877    inner.visit(0, &mut |e, _| match e {
1878        HirRelationExpr::Get {
1879            id: mz_expr::Id::Local(id),
1880            ..
1881        } => {
1882            if let Some(cte_desc) = cte_map.get(id) {
1883                let cte_outer_arity = cte_desc.outer_relation.arity();
1884                outer_cols.extend(
1885                    col_map
1886                        .inner
1887                        .iter()
1888                        .filter(|(_, position)| **position < cte_outer_arity)
1889                        .map(|(c, _)| {
1890                            // `col_map` maps column references to column positions in
1891                            // `outer`'s projection.
1892                            // `outer_cols` is meant to contain the external column
1893                            // references in `inner`.
1894                            // Since `inner` defines a new scope, any column reference
1895                            // in `col_map` is one level deeper when seen from within
1896                            // `inner`, hence the +1.
1897                            ColumnRef {
1898                                level: c.level + 1,
1899                                column: c.column,
1900                            }
1901                        }),
1902                );
1903            }
1904        }
1905        HirRelationExpr::Let { id, .. } => {
1906            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
1907            // we would need to remove the old CTE with the same ID temporarily while
1908            // traversing the definition of the new CTE under the same ID.
1909            assert!(!cte_map.contains_key(id));
1910        }
1911        _ => {}
1912    });
1913    let mut new_col_map = BTreeMap::new();
1914    let mut key = vec![];
1915    for col in outer_cols {
1916        new_col_map.insert(col, key.len());
1917        key.push(col_map.get(&ColumnRef {
1918            // Note: `outer_cols` contains the external column references within `inner`.
1919            // We must compensate for `inner`'s scope when translating column references
1920            // as seen within `inner` to column references as seen from `outer`'s context,
1921            // hence the -1.
1922            level: col.level - 1,
1923            column: col.column,
1924        }));
1925    }
1926    let new_col_map = ColumnMap::new(new_col_map);
1927    outer.let_in(id_gen, |id_gen, get_outer| {
1928        let keyed_outer = if key.is_empty() {
1929            // Don't depend on outer at all if the branch is not correlated,
1930            // which yields vastly better query plans. Note that this is a bit
1931            // weird in that the branch will be computed even if outer has no
1932            // rows, whereas if it had been correlated it would not (and *could*
1933            // not) have been computed if outer had no rows, but the callers of
1934            // this function don't mind these somewhat-weird semantics.
1935            MirRelationExpr::constant(vec![vec![]], ReprRelationType::new(vec![]))
1936        } else {
1937            get_outer.clone().distinct_by(key.clone())
1938        };
1939        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
1940            let oa = get_outer.arity();
1941            let branch = apply(
1942                id_gen,
1943                inner,
1944                get_keyed_outer,
1945                &new_col_map,
1946                cte_map,
1947                context,
1948            )?;
1949            let ba = branch.arity();
1950            let joined = MirRelationExpr::join(
1951                vec![get_outer.clone(), branch],
1952                key.iter()
1953                    .enumerate()
1954                    .map(|(i, &k)| vec![(0, k), (1, i)])
1955                    .collect(),
1956            )
1957            // throw away the right-hand copy of the key we just joined on
1958            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
1959            Ok(joined)
1960        })
1961    })
1962}
1963
1964fn apply_scalar_subquery(
1965    id_gen: &mut mz_ore::id_gen::IdGen,
1966    outer: MirRelationExpr,
1967    col_map: &ColumnMap,
1968    cte_map: &mut CteMap,
1969    scalar_subquery: HirRelationExpr,
1970    apply_requires_distinct_outer: bool,
1971    context: &Context,
1972) -> Result<MirRelationExpr, PlanError> {
1973    branch(
1974        id_gen,
1975        outer,
1976        col_map,
1977        cte_map,
1978        scalar_subquery,
1979        apply_requires_distinct_outer,
1980        context,
1981        |id_gen, expr, get_inner, col_map, cte_map, context| {
1982            // compute for every row in get_inner
1983            let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1984            let col_type = select.sql_typ().column_types.into_last();
1985
1986            let inner_arity = get_inner.arity();
1987            // We must determine a count for each `get_inner` prefix,
1988            // and report an error if that count exceeds one.
1989            let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1990                // Count for each `get_inner` prefix.
1991                let counts = get_select.clone().reduce(
1992                    (0..inner_arity).collect::<Vec<_>>(),
1993                    vec![mz_expr::AggregateExpr {
1994                        func: mz_expr::AggregateFunc::Count,
1995                        expr: MirScalarExpr::literal_true(),
1996                        distinct: false,
1997                    }],
1998                    None,
1999                );
2000
2001                let use_guard = context.config.enable_guard_subquery_tablefunc;
2002
2003                // Errors should result from counts > 1.
2004                let errors = if use_guard {
2005                    counts
2006                        .flat_map(
2007                            mz_expr::TableFunc::GuardSubquerySize {
2008                                column_type: col_type.clone().scalar_type,
2009                            },
2010                            vec![MirScalarExpr::column(inner_arity)],
2011                        )
2012                        .project(
2013                            (0..inner_arity)
2014                                .chain(Some(inner_arity + 1))
2015                                .collect::<Vec<_>>(),
2016                        )
2017                } else {
2018                    counts
2019                        .filter(vec![MirScalarExpr::column(inner_arity).call_binary(
2020                            MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2021                            func::Gt,
2022                        )])
2023                        .project((0..inner_arity).collect::<Vec<_>>())
2024                        .map_one(MirScalarExpr::literal(
2025                            Err(mz_expr::EvalError::MultipleRowsFromSubquery),
2026                            ReprScalarType::from(&col_type.scalar_type),
2027                        ))
2028                };
2029                // Return `get_select` and any errors added in.
2030                Ok::<_, PlanError>(get_select.union(errors))
2031            })?;
2032            // append Null to anything that didn't return any rows
2033            let default = vec![(Datum::Null, ReprScalarType::from(&col_type.scalar_type))];
2034            get_inner.lookup(id_gen, guarded, default)
2035        },
2036    )
2037}
2038
2039fn apply_existential_subquery(
2040    id_gen: &mut mz_ore::id_gen::IdGen,
2041    outer: MirRelationExpr,
2042    col_map: &ColumnMap,
2043    cte_map: &mut CteMap,
2044    subquery_expr: HirRelationExpr,
2045    apply_requires_distinct_outer: bool,
2046    context: &Context,
2047) -> Result<MirRelationExpr, PlanError> {
2048    branch(
2049        id_gen,
2050        outer,
2051        col_map,
2052        cte_map,
2053        subquery_expr,
2054        apply_requires_distinct_outer,
2055        context,
2056        |id_gen, expr, get_inner, col_map, cte_map, context| {
2057            let exists = expr
2058                // compute for every row in get_inner
2059                .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2060                // throw away actual values and just remember whether or not there were __any__ rows
2061                .distinct_by((0..get_inner.arity()).collect())
2062                // Append true to anything that returned any rows.
2063                .map(vec![MirScalarExpr::literal_true()]);
2064
2065            // append False to anything that didn't return any rows
2066            get_inner.lookup(id_gen, exists, vec![(Datum::False, ReprScalarType::Bool)])
2067        },
2068    )
2069}
2070
2071impl AggregateExpr {
2072    fn applied_to(
2073        self,
2074        id_gen: &mut mz_ore::id_gen::IdGen,
2075        col_map: &ColumnMap,
2076        cte_map: &mut CteMap,
2077        inner: &mut MirRelationExpr,
2078        context: &Context,
2079    ) -> Result<mz_expr::AggregateExpr, PlanError> {
2080        let AggregateExpr {
2081            func,
2082            expr,
2083            distinct,
2084        } = self;
2085
2086        Ok(mz_expr::AggregateExpr {
2087            func: func.into_expr(),
2088            expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2089            distinct,
2090        })
2091    }
2092}
2093
2094/// Attempts an efficient outer join, if `on` has equijoin structure.
2095///
2096/// Both `left` and `right` are decorrelated inputs.
2097///
2098/// The first `oa` columns correspond to an outer context: we should do the
2099/// outer join independently for each prefix. In the case that `on` contains
2100/// just some equality tests between columns of `left` and `right` and some
2101/// local predicates, we can employ a relatively simple plan.
2102///
2103/// The last `on_subquery_types.len()` columns correspond to results from
2104/// subqueries defined in the `on` clause - we treat those as theta-join
2105/// conditions that prohibit the use of the simple plan attempted here.
2106fn attempt_outer_equijoin(
2107    left: MirRelationExpr,
2108    right: MirRelationExpr,
2109    on: MirScalarExpr,
2110    on_subquery_types: Vec<ReprColumnType>,
2111    kind: JoinKind,
2112    oa: usize,
2113    id_gen: &mut mz_ore::id_gen::IdGen,
2114    context: &Context,
2115) -> Result<Option<MirRelationExpr>, PlanError> {
2116    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
2117    // predicates that reference subqueries as long as these subqueries don't
2118    // reference `left` and `right` at the same time.
2119    //
2120    // TODO(database-issues#6828): This code can be improved as follows:
2121    //
2122    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
2123    // 2. Use the canonicalized `on` predicate in the non-equijoin based
2124    //    lowering strategy.
2125    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
2126    // 4. Pass the classified `OnPredicates` as a parameter.
2127    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
2128    //
2129    // Steps (1 + 2) require further investigation because we might change the
2130    // error semantics in case the `on` predicate contains a literal error..
2131
2132    let l_type = left.typ();
2133    let r_type = right.typ();
2134    let la = l_type.column_types.len() - oa;
2135    let ra = r_type.column_types.len() - oa;
2136    let sa = on_subquery_types.len();
2137
2138    // The output type contains [outer, left, right, sa] attributes.
2139    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
2140    output_type.extend(l_type.column_types);
2141    output_type.extend(r_type.column_types.into_iter().skip(oa));
2142    output_type.extend(on_subquery_types);
2143
2144    // Generally healthy to do, but specifically `USING` conditions sometimes
2145    // put an `AND true` at the end of the `ON` condition.
2146    //
2147    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
2148    // However, in that case it's not clear that we won't see regressions if
2149    // `on` simplifies to a literal error.
2150    let mut on = vec![on];
2151    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);
2152
2153    // Form the left and right types without the outer attributes.
2154    output_type.drain(0..oa);
2155    let lt = output_type.drain(0..la).collect_vec();
2156    let rt = output_type.drain(0..ra).collect_vec();
2157    assert!(output_type.len() == sa);
2158
2159    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
2160    if !on_predicates.is_equijoin(context) {
2161        return Ok(None);
2162    }
2163
2164    // If we've gotten this far, we can do the clever thing.
2165    // We'll want to use left and right multiple times
2166    let result = left.let_in(id_gen, |id_gen, get_left| {
2167        right.let_in(id_gen, |id_gen, get_right| {
2168            // TODO: we know that we can re-use the arrangements of left and right
2169            // needed for the inner join with each of the conditional outer joins.
2170            // It is not clear whether we should hint that, or just let the planner
2171            // and optimizer run and see what happens.
2172
2173            // We'll want the inner join (minus repeated columns)
2174            let join = MirRelationExpr::join(
2175                vec![get_left.clone(), get_right.clone()],
2176                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
2177            )
2178            // remove those columns from `right` repeating the first `oa` columns.
2179            .project(
2180                (0..(oa + la))
2181                    .chain((oa + la + oa)..(oa + la + oa + ra))
2182                    .collect(),
2183            )
2184            // apply the filter constraints here, to ensure nulls are not matched.
2185            .filter(on);
2186
2187            // We'll want to re-use the results of the join multiple times.
2188            join.let_in(id_gen, |id_gen, get_join| {
2189                let mut result = get_join.clone();
2190
2191                // A collection of keys present in both left and right collections.
2192                let join_keys = on_predicates.join_keys();
2193                let both_keys_arity = join_keys.len();
2194                let both_keys = get_join.restrict(join_keys).distinct();
2195
2196                // The plan is now to determine the left and right rows matched in the
2197                // inner join, subtract them from left and right respectively, pad what
2198                // remains with nulls, and fold them in to `result`.
2199
2200                both_keys.let_in(id_gen, |_id_gen, get_both| {
2201                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
2202                        // Rows in `left` matched in the inner equijoin. This is
2203                        // a semi-join between `left` and `both_keys`.
2204                        let left_present = MirRelationExpr::join_scalars(
2205                            vec![
2206                                get_left
2207                                    .clone()
2208                                    // Push local predicates.
2209                                    .filter(on_predicates.lhs()),
2210                                get_both.clone(),
2211                            ],
2212                            itertools::zip_eq(
2213                                on_predicates.eq_lhs(),
2214                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
2215                            )
2216                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
2217                            .collect(),
2218                        )
2219                        .project((0..(oa + la)).collect());
2220
2221                        // Determine the types of nulls to use as filler.
2222                        let right_fill = rt
2223                            .into_iter()
2224                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
2225                            .collect();
2226                        // Add to `result` absent elements, filled with typed nulls.
2227                        result = left_present
2228                            .negate()
2229                            .union(get_left.clone())
2230                            .map(right_fill)
2231                            .union(result);
2232                    }
2233
2234                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
2235                        // Rows in `right` matched in the inner equijoin. This
2236                        // is a semi-join between `right` and `both_keys`.
2237                        let right_present = MirRelationExpr::join_scalars(
2238                            vec![
2239                                get_right
2240                                    .clone()
2241                                    // Push local predicates.
2242                                    .filter(on_predicates.rhs()),
2243                                get_both,
2244                            ],
2245                            itertools::zip_eq(
2246                                on_predicates.eq_rhs(),
2247                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
2248                            )
2249                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
2250                            .collect(),
2251                        )
2252                        .project((0..(oa + ra)).collect());
2253
2254                        // Determine the types of nulls to use as filler.
2255                        let left_fill = lt
2256                            .into_iter()
2257                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
2258                            .collect();
2259
2260                        // Add to `result` absent elements, prepended with typed nulls.
2261                        result = right_present
2262                            .negate()
2263                            .union(get_right.clone())
2264                            .map(left_fill)
2265                            // Permute left fill before right values.
2266                            .project(
2267                                itertools::chain!(
2268                                    0..oa,                 // Preserve `outer`.
2269                                    oa + ra..oa + la + ra, // Increment the next `la` cols by `ra`.
2270                                    oa..oa + ra            // Decrement the next `ra` cols by `la`.
2271                                )
2272                                .collect(),
2273                            )
2274                            .union(result)
2275                    }
2276
2277                    Ok::<_, PlanError>(result)
2278                })
2279            })
2280        })
2281    })?;
2282    Ok(Some(result))
2283}
2284
2285/// A struct that represents the predicates in the `on` clause in a form
2286/// suitable for efficient planning outer joins with equijoin predicates.
2287struct OnPredicates {
2288    /// A store for classified `ON` predicates.
2289    ///
2290    /// Predicates that reference a single side are adjusted to assume an
2291    /// `outer × <side>` schema.
2292    predicates: Vec<OnPredicate>,
2293    /// Number of outer context columns.
2294    oa: usize,
2295}
2296
2297impl OnPredicates {
2298    const I_OUT: usize = 0; // outer context input position
2299    const I_LHS: usize = 1; // lhs input position
2300    const I_RHS: usize = 2; // rhs input position
2301    const I_SUB: usize = 3; // on subqueries input position
2302
2303    /// Classify the predicates in the `on` clause of an outer join.
2304    ///
2305    /// The other parameters are arities of the input parts:
2306    ///
2307    /// - `oa` is the arity of the `outer` context.
2308    /// - `la` is the arity of the `left` input.
2309    /// - `ra` is the arity of the `right` input.
2310    /// - `sa` is the arity of the `on` subqueries.
2311    ///
2312    /// The constructor assumes that:
2313    ///
2314    /// 1. The `on` parameter will be applied on a result that has the following
2315    ///    schema `outer × left × right × on_subqueries`.
2316    /// 2. The `on` parameter is already adjusted to assume that schema.
2317    /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2318    ///    MirScalarExpr` with `canonicalize_predicates`.
2319    fn new(
2320        oa: usize,
2321        la: usize,
2322        ra: usize,
2323        sa: usize,
2324        on: Vec<MirScalarExpr>,
2325        _context: &Context,
2326    ) -> Self {
2327        use mz_expr::BinaryFunc::Eq;
2328
2329        // Re-bind those locally for more compact pattern matching.
2330        const I_LHS: usize = OnPredicates::I_LHS;
2331        const I_RHS: usize = OnPredicates::I_RHS;
2332
2333        // Self parameters.
2334        let mut predicates = Vec::with_capacity(on.len());
2335
2336        // Helpers for populating `predicates`.
2337        let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2338        let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2339        let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2340            inner_join_mapper
2341                .lookup_inputs(expr)
2342                .filter(|&i| i != Self::I_OUT)
2343                .collect()
2344        };
2345        let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2346            inner_join_mapper
2347                .lookup_inputs(expr)
2348                .any(|i| i == Self::I_SUB)
2349        };
2350
2351        // Iterate over `on` elements and populate `predicates`.
2352        for mut predicate in on {
2353            if predicate.might_error() {
2354                tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2355                // Treat predicates that can produce a literal error as Theta.
2356                predicates.push(OnPredicate::Theta(predicate));
2357            } else if has_subquery_refs(&predicate) {
2358                tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2359                // Treat predicates referencing an `on` subquery as Theta.
2360                predicates.push(OnPredicate::Theta(predicate));
2361            } else if let MirScalarExpr::CallBinary {
2362                func: Eq(_),
2363                expr1,
2364                expr2,
2365            } = &mut predicate
2366            {
2367                // Obtain the non-outer inputs referenced by each side.
2368                let inputs1 = lookup_inputs(expr1);
2369                let inputs2 = lookup_inputs(expr2);
2370
2371                match (&inputs1[..], &inputs2[..]) {
2372                    // Neither side references an input. This could be a
2373                    // constant expression or an expression that depends only on
2374                    // the outer context.
2375                    ([], []) => {
2376                        predicates.push(OnPredicate::Const(predicate));
2377                    }
2378                    // Both sides reference different inputs.
2379                    ([I_LHS], [I_RHS]) => {
2380                        let lhs = expr1.take();
2381                        let mut rhs = expr2.take();
2382                        rhs.permute(&rhs_permutation);
2383                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2384                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2385                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2386                    }
2387                    // Both sides reference different inputs (swapped).
2388                    ([I_RHS], [I_LHS]) => {
2389                        let lhs = expr2.take();
2390                        let mut rhs = expr1.take();
2391                        rhs.permute(&rhs_permutation);
2392                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2393                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2394                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2395                    }
2396                    // Both sides reference the left input or no input.
2397                    ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2398                        predicates.push(OnPredicate::Lhs(predicate));
2399                    }
2400                    // Both sides reference the right input or no input.
2401                    ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2402                        predicate.permute(&rhs_permutation);
2403                        predicates.push(OnPredicate::Rhs(predicate));
2404                    }
2405                    // At least one side references more than one input.
2406                    _ => {
2407                        tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2408                        predicates.push(OnPredicate::Theta(predicate));
2409                    }
2410                }
2411            } else {
2412                // Obtain the non-outer inputs referenced by this predicate.
2413                let inputs = lookup_inputs(&predicate);
2414
2415                match &inputs[..] {
2416                    // The predicate references no inputs. This could be a
2417                    // constant expression or an expression that depends only on
2418                    // the outer context.
2419                    [] => {
2420                        predicates.push(OnPredicate::Const(predicate));
2421                    }
2422                    // The predicate references only the left input.
2423                    [I_LHS] => {
2424                        predicates.push(OnPredicate::Lhs(predicate));
2425                    }
2426                    // The predicate references only the right input.
2427                    [I_RHS] => {
2428                        predicate.permute(&rhs_permutation);
2429                        predicates.push(OnPredicate::Rhs(predicate));
2430                    }
2431                    // The predicate references both inputs.
2432                    _ => {
2433                        tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2434                        predicates.push(OnPredicate::Theta(predicate));
2435                    }
2436                }
2437            }
2438        }
2439
2440        Self { predicates, oa }
2441    }
2442
2443    /// Check if the predicates can be lowered with an equijoin-based strategy.
2444    fn is_equijoin(&self, context: &Context) -> bool {
2445        // Count each `OnPredicate` variant in `self.predicates`.
2446        let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2447            self.predicates.iter().fold(
2448                (0, 0, 0, 0, 0, 0),
2449                |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2450                    (
2451                        const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2452                        lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2453                        rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2454                        eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2455                        eq_cols
2456                            + usize::from(matches!(
2457                                p,
2458                                OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column()
2459                            )),
2460                        theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2461                    )
2462                },
2463            );
2464
2465        let is_equijion = if context.config.enable_new_outer_join_lowering {
2466            // New classifier.
2467            eq_cnt > 0 && theta_cnt == 0
2468        } else {
2469            // Old classifier.
2470            eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2471        };
2472
2473        // Log an entry only if this is an equijoin according to the new classifier.
2474        if eq_cnt > 0 && theta_cnt == 0 {
2475            tracing::debug!(
2476                const_cnt,
2477                lhs_cnt,
2478                rhs_cnt,
2479                eq_cnt,
2480                eq_cols,
2481                theta_cnt,
2482                "OnPredicates::is_equijoin"
2483            );
2484        }
2485
2486        is_equijion
2487    }
2488
2489    /// Return an [`MirRelationExpr`] list that represents the keys for the
2490    /// equijoin. The list will contain the outer columns as a prefix.
2491    fn join_keys(&self) -> JoinKeys {
2492        // We could return either the `lhs` or the `rhs` of the keys used to
2493        // form the inner join as they are equated by the join condition.
2494        let join_keys = self.eq_lhs().collect::<Vec<_>>();
2495
2496        if join_keys.iter().all(|k| k.is_column()) {
2497            tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2498            JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2499        } else {
2500            tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2501            JoinKeys::Scalars(join_keys)
2502        }
2503    }
2504
2505    /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2506    /// conditions in the predicates list.
2507    ///
2508    /// The iterator will start with column references to the outer columns as a
2509    /// prefix.
2510    fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2511        itertools::chain(
2512            (0..self.oa).map(MirScalarExpr::column),
2513            self.predicates.iter().filter_map(|e| match e {
2514                OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2515                _ => None,
2516            }),
2517        )
2518    }
2519
2520    /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2521    /// conditions in the predicates list.
2522    ///
2523    /// The iterator will start with column references to the outer columns as a
2524    /// prefix.
2525    fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2526        itertools::chain(
2527            (0..self.oa).map(MirScalarExpr::column),
2528            self.predicates.iter().filter_map(|e| match e {
2529                OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2530                _ => None,
2531            }),
2532        )
2533    }
2534
2535    /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2536    /// [`OnPredicate::Const`] conditions in the predicates list.
2537    fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2538        self.predicates.iter().filter_map(|p| match p {
2539            // We treat Const predicates local to both inputs.
2540            OnPredicate::Const(p) => Some(p.clone()),
2541            OnPredicate::Lhs(p) => Some(p.clone()),
2542            OnPredicate::LhsConsequence(p) => Some(p.clone()),
2543            _ => None,
2544        })
2545    }
2546
2547    /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2548    /// [`OnPredicate::Const`] conditions in the predicates list.
2549    fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2550        self.predicates.iter().filter_map(|p| match p {
2551            // We treat Const predicates local to both inputs.
2552            OnPredicate::Const(p) => Some(p.clone()),
2553            OnPredicate::Rhs(p) => Some(p.clone()),
2554            OnPredicate::RhsConsequence(p) => Some(p.clone()),
2555            _ => None,
2556        })
2557    }
2558}
2559
// Classification of a single predicate from a join's ON clause, according to
// which join inputs the predicate refers to. Every variant carries the
// predicate expression(s) it classifies.
enum OnPredicate {
    // A predicate that is either constant or references only outer columns.
    //
    // It references neither join input, so it is treated as local to both.
    Const(MirScalarExpr),
    // A local predicate on the left-hand side of the join, i.e., it references only the left input
    // and possibly outer columns.
    //
    // This is one of the original predicates from the ON clause.
    //
    // One _must_ apply this predicate.
    Lhs(MirScalarExpr),
    // A local predicate on the left-hand side of the join, i.e., it references only the left input
    // and possibly outer columns.
    //
    // This is not one of the original predicates from the ON clause, but is just a consequence
    // of an original predicate in the ON clause, where the original predicate references both
    // inputs, but the consequence references only the left input.
    //
    // For example, the original predicate `input1.x = input2.a` has the consequence
    // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
    // prevents null skew, and also makes more CSE opportunities available when the left input's key
    // doesn't have a NOT NULL constraint, saving us an arrangement.
    //
    // Applying the predicate is optional, because the original predicate will be applied anyway.
    LhsConsequence(MirScalarExpr),
    // A local predicate on the right-hand side of the join.
    //
    // This is one of the original predicates from the ON clause.
    //
    // One _must_ apply this predicate.
    Rhs(MirScalarExpr),
    // A consequence of an original ON predicate that references only the right input; the
    // right-hand side counterpart of `LhsConsequence` (see that variant for the rationale).
    //
    // Applying the predicate is optional, because the original predicate will be applied anyway.
    RhsConsequence(MirScalarExpr),
    // An equality predicate between the two sides; the two expressions are the operands being
    // compared.
    // NOTE(review): presumably the first operand belongs to the left input and the second to the
    // right input -- confirm against the code that constructs this variant.
    Eq(MirScalarExpr, MirScalarExpr),
    // A non-equality predicate between the two sides.
    // NOTE(review): `dead_code` is allowed here, presumably because the payload is never read --
    // confirm and record the reason.
    #[allow(dead_code)]
    Theta(MirScalarExpr),
}
2598
/// A set of join keys referencing an input.
///
/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
/// avoid changes (and thereby possible regressions) in plans that have equijoin
/// predicates consisting only of column refs.
///
/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
/// able to get rid of this code, but as it stands `Map` simplification seems
/// more cumbersome than `Project` simplification, so do this just to be sure.
enum JoinKeys {
    // Join keys given as column indices of the input. `LoweringExt::restrict`
    // handles this case with a plain projection onto these columns.
    Outputs(Vec<usize>),
    // Join keys given as scalar expressions over the input. `LoweringExt::restrict`
    // handles this case by mapping the expressions onto the input and then
    // projecting onto the newly appended columns.
    Scalars(Vec<MirScalarExpr>),
}
2614
2615impl JoinKeys {
2616    fn len(&self) -> usize {
2617        match self {
2618            JoinKeys::Outputs(outputs) => outputs.len(),
2619            JoinKeys::Scalars(scalars) => scalars.len(),
2620        }
2621    }
2622}
2623
/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
/// code.
trait LoweringExt {
    /// Consume `self` and return a relation restricted to the given
    /// [`JoinKeys`].
    ///
    /// See [`MirRelationExpr::restrict`] for the concrete behavior.
    fn restrict(self, join_keys: JoinKeys) -> Self;
}
2630
2631impl LoweringExt for MirRelationExpr {
2632    /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2633    fn restrict(self, join_keys: JoinKeys) -> Self {
2634        let num_keys = join_keys.len();
2635        match join_keys {
2636            JoinKeys::Outputs(outputs) => self.project(outputs),
2637            JoinKeys::Scalars(scalars) => {
2638                let input_arity = self.arity();
2639                let outputs = (input_arity..input_arity + num_keys).collect();
2640                self.map(scalars).project(outputs)
2641            }
2642        }
2643    }
2644}