mz_sql/plan/
lowering.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::visit::Visit;
44use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
45use mz_ore::collections::CollectionExt;
46use mz_ore::stack::maybe_grow;
47use mz_repr::*;
48
49use crate::optimizer_metrics::OptimizerMetrics;
50use crate::plan::hir::{
51    AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
52};
53use crate::plan::{PlanError, transform_hir};
54use crate::session::vars::SystemVars;
55
56mod variadic_left;
57
58/// Maps a leveled column reference to a specific column.
59///
60/// Leveled column references are nested, so that larger levels are
61/// found early in a record and level zero is found at the end.
62///
63/// The column map only stores references for levels greater than zero,
64/// and column references at level zero simply start at the first column
65/// after all prior references.
66#[derive(Debug, Clone)]
67struct ColumnMap {
68    inner: BTreeMap<ColumnRef, usize>,
69}
70
71impl ColumnMap {
72    fn empty() -> ColumnMap {
73        Self::new(BTreeMap::new())
74    }
75
76    fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
77        ColumnMap { inner }
78    }
79
80    fn get(&self, col_ref: &ColumnRef) -> usize {
81        if col_ref.level == 0 {
82            self.inner.len() + col_ref.column
83        } else {
84            self.inner[col_ref]
85        }
86    }
87
88    fn len(&self) -> usize {
89        self.inner.len()
90    }
91
92    /// Updates references in the `ColumnMap` for use in a nested scope. The
93    /// provided `arity` must specify the arity of the current scope.
94    fn enter_scope(&self, arity: usize) -> ColumnMap {
95        // From the perspective of the nested scope, all existing column
96        // references will be one level greater.
97        let existing = self
98            .inner
99            .clone()
100            .into_iter()
101            .update(|(col, _i)| col.level += 1);
102
103        // All columns in the current scope become explicit entries in the
104        // immediate parent scope.
105        let new = (0..arity).map(|i| {
106            (
107                ColumnRef {
108                    level: 1,
109                    column: i,
110                },
111                self.len() + i,
112            )
113        });
114
115        ColumnMap::new(existing.chain(new).collect())
116    }
117}
118
/// Map with the CTEs currently in scope.
///
/// Keyed by the `LocalId` under which the CTE is referenced in the expression
/// being lowered; the value describes the lowered binding to use instead.
type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
121
/// Information needed to resolve a reference to a CTE that is in scope.
#[derive(Clone)]
struct CteDesc {
    /// The new ID assigned to the lowered version of the CTE, which may not match
    /// the ID of the input CTE.
    new_id: mz_expr::LocalId,
    /// The relation type of the lowered CTE: the columns from the outer
    /// context come first, followed by the CTE's own columns.
    relation_type: SqlRelationType,
    /// The outer relation the CTE was applied to. A reference lowered against
    /// this exact relation can be a plain `Get`; any other reference has to be
    /// joined with its own outer relation on the shared leading columns.
    outer_relation: MirRelationExpr,
}
134
/// Feature flags affecting the behavior of lowering.
///
/// Every flag is disabled in the `Default` value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
    /// Mirrors the system variable of the same name.
    /// NOTE(review): the code consulting this flag is not visible here —
    /// confirm its exact effect before relying on this description.
    pub enable_guard_subquery_tablefunc: bool,
    /// Mirrors the system variable of the same name.
    /// NOTE(review): the code consulting this flag is not visible here —
    /// confirm its exact effect before relying on this description.
    pub enable_cast_elimination: bool,
}
155
156impl From<&SystemVars> for Config {
157    fn from(vars: &SystemVars) -> Self {
158        Self {
159            enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
160            enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
161            enable_guard_subquery_tablefunc: vars.enable_guard_subquery_tablefunc(),
162            enable_cast_elimination: vars.enable_cast_elimination(),
163        }
164    }
165}
166
/// Context passed to the lowering. This is wired to most parts of the lowering.
///
/// It only holds borrows (valid for `'a`); `HirRelationExpr::lower` builds one
/// per call and hands a reference to it down the recursion.
pub(crate) struct Context<'a> {
    /// Feature flags affecting the behavior of lowering.
    pub config: &'a Config,
    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's `None`,
    /// we simply don't write metrics.
    pub metrics: Option<&'a OptimizerMetrics>,
}
175
176impl HirRelationExpr {
177    /// Rewrite `self` into a `MirRelationExpr`.
178    /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
179    #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
180    pub fn lower<C: Into<Config>>(
181        self,
182        config: C,
183        metrics: Option<&OptimizerMetrics>,
184    ) -> Result<MirRelationExpr, PlanError> {
185        let context = Context {
186            config: &config.into(),
187            metrics,
188        };
189        let result = match self {
190            // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
191            // to ensure that the downstream optimizer can easily bypass most
192            // irrelevant optimizations (e.g. reduce folding) for this expression
193            // without having to re-learn the fact that it is just a constant,
194            // as it would if the constant were wrapped in a Let-Get pair.
195            HirRelationExpr::Constant { rows, typ } => {
196                let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
197                MirRelationExpr::Constant {
198                    rows: Ok(rows),
199                    typ,
200                }
201            }
202            mut other => {
203                let mut id_gen = mz_ore::id_gen::IdGen::default();
204                transform_hir::split_subquery_predicates(&mut other)?;
205                transform_hir::try_simplify_quantified_comparisons(&mut other)?;
206                transform_hir::fuse_window_functions(&mut other, &context)?;
207                MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![])).let_in(
208                    &mut id_gen,
209                    |id_gen, get_outer| {
210                        other.applied_to(
211                            id_gen,
212                            get_outer,
213                            &ColumnMap::empty(),
214                            &mut CteMap::new(),
215                            &context,
216                        )
217                    },
218                )?
219            }
220        };
221
222        mz_repr::explain::trace_plan(&result);
223
224        Ok(result)
225    }
226
227    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
228    ///
229    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
230    /// When `self` references columns of `get_outer`, much more work needs to occur.
231    ///
232    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
233    /// perhaps not all of them. It should be used as the basis of resolving column references,
234    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
235    /// will start, rather than any function of `col_map`.
236    ///
237    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
238    /// assignment of values to outer rows.
239    fn applied_to(
240        self,
241        id_gen: &mut mz_ore::id_gen::IdGen,
242        get_outer: MirRelationExpr,
243        col_map: &ColumnMap,
244        cte_map: &mut CteMap,
245        context: &Context,
246    ) -> Result<MirRelationExpr, PlanError> {
247        maybe_grow(|| {
248            use MirRelationExpr as SR;
249
250            use HirRelationExpr::*;
251
252            if let MirRelationExpr::Get { .. } = &get_outer {
253            } else {
254                panic!(
255                    "get_outer: expected a MirRelationExpr::Get, found\n{}",
256                    get_outer.pretty(),
257                );
258            }
259            assert_eq!(col_map.len(), get_outer.arity());
260            Ok(match self {
261                Constant { rows, typ } => {
262                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
263                    get_outer.product(SR::Constant {
264                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
265                        typ,
266                    })
267                }
268                Get { id, typ } => match id {
269                    mz_expr::Id::Local(local_id) => {
270                        let cte_desc = cte_map.get(&local_id).unwrap();
271                        let get_cte = SR::Get {
272                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
273                            typ: cte_desc.relation_type.clone(),
274                            access_strategy: AccessStrategy::UnknownOrLocal,
275                        };
276                        if get_outer == cte_desc.outer_relation {
277                            // If the CTE was applied to the same exact relation, we can safely
278                            // return a `Get` relation.
279                            get_cte
280                        } else {
281                            // Otherwise, the new outer relation may contain more columns from some
282                            // intermediate scope placed between the definition of the CTE and this
283                            // reference of the CTE and/or more operations applied on top of the
284                            // outer relation.
285                            //
286                            // An example of the latter is the following query:
287                            //
288                            // SELECT *
289                            // FROM x,
290                            //      LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
291                            //              SELECT (SELECT m FROM a) FROM y) b;
292                            //
293                            // When the CTE is lowered, the outer relation is `Get x`. But then,
294                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
295                            // which has the same cardinality as `Get x`.
296                            //
                            // In any case, `get_outer` is guaranteed to contain the columns of the
                            // outer relation the CTE was applied to as a prefix. Since we must
                            // return a relation containing `get_outer`'s columns at the beginning,
                            // we build a join between `get_outer` and `get_cte` on their common
                            // columns.
302                            let oa = get_outer.arity();
303                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
304                            let equivalences = (0..cte_outer_columns)
305                                .map(|pos| {
306                                    vec![
307                                        MirScalarExpr::column(pos),
308                                        MirScalarExpr::column(pos + oa),
309                                    ]
310                                })
311                                .collect();
312
                            // Project out the second copy of the columns common to `get_outer`
                            // and `cte_desc.outer_relation`.
315                            let projection = (0..oa)
316                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
317                                .collect_vec();
318                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
319                                .project(projection)
320                        }
321                    }
322                    _ => {
323                        // Get statements are only to external sources, and are not correlated with `get_outer`.
324                        get_outer.product(SR::Get {
325                            id,
326                            typ,
327                            access_strategy: AccessStrategy::UnknownOrLocal,
328                        })
329                    }
330                },
331                Let {
332                    name: _,
333                    id,
334                    value,
335                    body,
336                } => {
337                    let value =
338                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
339                    value.let_in(id_gen, |id_gen, get_value| {
340                        let (new_id, typ) = if let MirRelationExpr::Get {
341                            id: mz_expr::Id::Local(id),
342                            typ,
343                            ..
344                        } = get_value
345                        {
346                            (id, typ)
347                        } else {
348                            panic!(
349                                "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
350                                get_value.pretty(),
351                            );
352                        };
353                        // Add the information about the CTE to the map and remove it when
354                        // it goes out of scope.
355                        let old_value = cte_map.insert(
356                            id.clone(),
357                            CteDesc {
358                                new_id,
359                                relation_type: typ,
360                                outer_relation: get_outer.clone(),
361                            },
362                        );
363                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
364                        if let Some(old_value) = old_value {
365                            cte_map.insert(id, old_value);
366                        } else {
367                            cte_map.remove(&id);
368                        }
369                        body
370                    })?
371                }
372                LetRec {
373                    limit,
374                    bindings,
375                    body,
376                } => {
377                    let num_bindings = bindings.len();
378
379                    // We use the outer type with the HIR types to form MIR CTE types.
380                    let outer_column_types = get_outer.typ().column_types;
381
382                    // Rename and introduce all bindings.
383                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
384                    let mut mir_ids = Vec::with_capacity(num_bindings);
385                    for (_name, id, _value, typ) in bindings.iter() {
386                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
387                        mir_ids.push(mir_id);
388                        let shadowed = cte_map.insert(
389                            id.clone(),
390                            CteDesc {
391                                new_id: mir_id,
392                                relation_type: SqlRelationType::new(
393                                    outer_column_types
394                                        .iter()
395                                        .cloned()
396                                        .chain(typ.column_types.iter().cloned())
397                                        .collect::<Vec<_>>(),
398                                ),
399                                outer_relation: get_outer.clone(),
400                            },
401                        );
402                        shadowed_bindings.push((*id, shadowed));
403                    }
404
405                    let mut mir_values = Vec::with_capacity(num_bindings);
406                    for (_name, _id, value, _typ) in bindings.into_iter() {
407                        mir_values.push(value.applied_to(
408                            id_gen,
409                            get_outer.clone(),
410                            col_map,
411                            cte_map,
412                            context,
413                        )?);
414                    }
415
416                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
417
418                    // Remove our bindings and reinstate any shadowed bindings.
419                    for (id, shadowed) in shadowed_bindings {
420                        if let Some(shadowed) = shadowed {
421                            cte_map.insert(id, shadowed);
422                        } else {
423                            cte_map.remove(&id);
424                        }
425                    }
426
427                    MirRelationExpr::LetRec {
428                        ids: mir_ids,
429                        values: mir_values,
430                        // Copy the limit to each binding.
431                        limits: repeat(limit).take(num_bindings).collect(),
432                        body: Box::new(mir_body),
433                    }
434                }
435                Project { input, outputs } => {
436                    // Projections should be applied to the decorrelated `inner`, and to its columns,
437                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
438                    let input =
439                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
440                    let outputs = (0..get_outer.arity())
441                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
442                        .collect::<Vec<_>>();
443                    input.project(outputs)
444                }
445                Map { input, mut scalars } => {
446                    // Scalar expressions may contain correlated subqueries. We must be cautious!
447
448                    // We lower scalars in chunks, and must keep track of the
449                    // arity of the HIR fragments lowered so far.
450                    let mut lowered_arity = input.arity();
451
452                    let mut input =
453                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
454
                    // Lower subqueries in maximally sized batches, such that no subquery in the
                    // current batch depends on columns from the same batch.
457                    // Note that subqueries in this projection may reference columns added by this
458                    // Map operator, so we need to ensure these columns exist before lowering the
459                    // subquery.
460                    while !scalars.is_empty() {
461                        let end_idx = scalars
462                            .iter_mut()
463                            .position(|s| {
464                                let mut requires_nonexistent_column = false;
465                                #[allow(deprecated)]
466                                s.visit_columns(0, &mut |depth, col| {
467                                    if col.level == depth {
468                                        requires_nonexistent_column |= col.column >= lowered_arity
469                                    }
470                                });
471                                requires_nonexistent_column
472                            })
473                            .unwrap_or(scalars.len());
474                        assert!(
475                            end_idx > 0,
476                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
477                            lowered_arity,
478                            scalars
479                        );
480
481                        lowered_arity = lowered_arity + end_idx;
482                        let scalars = scalars.drain(0..end_idx).collect_vec();
483
484                        let old_arity = input.arity();
485                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
486                            &scalars, id_gen, col_map, cte_map, input, context,
487                        )?;
488                        input = with_subqueries;
489
490                        // We will proceed sequentially through the scalar expressions, for each transforming
491                        // the decorrelated `input` into a relation with potentially more columns capable of
492                        // addressing the needs of the scalar expression.
493                        // Having done so, we add the scalar value of interest and trim off any other newly
494                        // added columns.
495                        //
496                        // The sequential traversal is present as expressions are allowed to depend on the
497                        // values of prior expressions.
498                        let mut scalar_columns = Vec::new();
499                        for scalar in scalars {
500                            let scalar = scalar.applied_to(
501                                id_gen,
502                                col_map,
503                                cte_map,
504                                &mut input,
505                                &Some(&subquery_map),
506                                context,
507                            )?;
508                            input = input.map_one(scalar);
509                            scalar_columns.push(input.arity() - 1);
510                        }
511
512                        // Discard any new columns added by the lowering of the scalar expressions
513                        input = input.project((0..old_arity).chain(scalar_columns).collect());
514                    }
515
516                    input
517                }
518                CallTable { func, exprs } => {
519                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
520                    // allowed to refer to the results of previous expressions, and we have a simpler
521                    // implementation that appends all relevant columns first, then applies the flatmap
                    // operator to the result, then strips off any columns introduced by subqueries.
523
524                    let mut input = get_outer;
525                    let old_arity = input.arity();
526
527                    let exprs = exprs
528                        .into_iter()
529                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
530                        .collect::<Result<Vec<_>, _>>()?;
531
532                    let new_arity = input.arity();
533                    let output_arity = func.output_arity();
534                    input = input.flat_map(func, exprs);
535                    if old_arity != new_arity {
536                        // this means we added some columns to handle subqueries, and now we need to get rid of them
537                        input = input.project(
538                            (0..old_arity)
539                                .chain(new_arity..new_arity + output_arity)
540                                .collect(),
541                        );
542                    }
543                    input
544                }
545                Filter { input, predicates } => {
546                    // Filter expressions may contain correlated subqueries.
547                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
548                    // then filter the results, then strip off any columns that were added for this purpose.
549                    let mut input =
550                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
551                    for predicate in predicates {
552                        let old_arity = input.arity();
553                        let predicate = predicate
554                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
555                        let new_arity = input.arity();
556                        input = input.filter(vec![predicate]);
557                        if old_arity != new_arity {
558                            // this means we added some columns to handle subqueries, and now we need to get rid of them
559                            input = input.project((0..old_arity).collect());
560                        }
561                    }
562                    input
563                }
564                Join {
565                    left,
566                    right,
567                    on,
568                    kind,
569                } if right.is_correlated() => {
570                    // A correlated join is a join in which the right expression has
571                    // access to the columns in the left expression. It turns out
572                    // this is *exactly* our branch operator, plus some additional
573                    // null handling in the case of left joins. (Right and full
574                    // lateral joins are not permitted.)
575                    //
576                    // As with normal joins, the `on` predicate may be correlated,
577                    // and we treat it as a filter that follows the branch.
578
579                    assert!(kind.can_be_correlated());
580
581                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
582                    left.let_in(id_gen, |id_gen, get_left| {
583                        let apply_requires_distinct_outer = false;
584                        let mut join = branch(
585                            id_gen,
586                            get_left.clone(),
587                            col_map,
588                            cte_map,
589                            *right,
590                            apply_requires_distinct_outer,
591                            context,
592                            |id_gen, right, get_left, col_map, cte_map, context| {
593                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
594                            },
595                        )?;
596
597                        // Plan the `on` predicate.
598                        let old_arity = join.arity();
599                        let on =
600                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
601                        join = join.filter(vec![on]);
602                        let new_arity = join.arity();
603                        if old_arity != new_arity {
604                            // This means we added some columns to handle
605                            // subqueries, and now we need to get rid of them.
606                            join = join.project((0..old_arity).collect());
607                        }
608
609                        // If a left join, reintroduce any rows from the left that
610                        // are missing, with nulls filled in for the right columns.
611                        if let JoinKind::LeftOuter { .. } = kind {
612                            let default = join
613                                .typ()
614                                .column_types
615                                .into_iter()
616                                .skip(get_left.arity())
617                                .map(|typ| (Datum::Null, typ.scalar_type))
618                                .collect();
619                            get_left.lookup(id_gen, join, default)
620                        } else {
621                            Ok::<_, PlanError>(join)
622                        }
623                    })?
624                }
625                Join {
626                    left,
627                    right,
628                    on,
629                    kind,
630                } => {
631                    if context.config.enable_variadic_left_join_lowering {
632                        // Attempt to extract a stack of left joins.
633                        if let JoinKind::LeftOuter = kind {
634                            let mut rights = vec![(&*right, &on)];
635                            let mut left_test = &left;
636                            while let Join {
637                                left,
638                                right,
639                                on,
640                                kind: JoinKind::LeftOuter,
641                            } = &**left_test
642                            {
643                                rights.push((&**right, on));
644                                left_test = left;
645                            }
646                            if rights.len() > 1 {
647                                // Defensively clone `cte_map` as it may be mutated.
648                                let cte_map_clone = cte_map.clone();
649                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
650                                    left_test,
651                                    rights,
652                                    id_gen,
653                                    get_outer.clone(),
654                                    col_map,
655                                    cte_map,
656                                    context,
657                                ) {
658                                    return Ok(magic);
659                                } else {
660                                    cte_map.clone_from(&cte_map_clone);
661                                }
662                            }
663                        }
664                    }
665
666                    // Both join expressions should be decorrelated, and then joined by their
667                    // leading columns to form only those pairs corresponding to the same row
668                    // of `get_outer`.
669                    //
670                    // The `on` predicate may contain correlated subqueries, and we treat it
671                    // as though it was a filter, with the caveat that we also translate outer
672                    // joins in this step. The post-filtration results need to be considered
673                    // against the records present in the left and right (decorrelated) inputs,
674                    // depending on the type of join.
675                    let oa = get_outer.arity();
676                    let left =
677                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
678                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
679                    let la = lt.len();
680                    left.let_in(id_gen, |id_gen, get_left| {
681                        let right_col_map = col_map.enter_scope(0);
682                        let right = right.applied_to(
683                            id_gen,
684                            get_outer.clone(),
685                            &right_col_map,
686                            cte_map,
687                            context,
688                        )?;
689                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
690                        let ra = rt.len();
691                        right.let_in(id_gen, |id_gen, get_right| {
692                            let mut product = SR::join(
693                                vec![get_left.clone(), get_right.clone()],
694                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
695                            )
696                            // Project away the repeated copy of get_outer's columns.
697                            .project(
698                                (0..(oa + la))
699                                    .chain((oa + la + oa)..(oa + la + oa + ra))
700                                    .collect(),
701                            );
702
703                            // Decorrelate and lower the `on` clause.
704                            let on = on.applied_to(
705                                id_gen,
706                                col_map,
707                                cte_map,
708                                &mut product,
709                                &None,
710                                context,
711                            )?;
712                            // Collect the types of all subqueries appearing in
713                            // the `on` clause. The subquery results were
714                            // appended to `product` in the `on.applied_to(...)`
715                            // call above.
716                            let on_subquery_types = product
717                                .typ()
718                                .column_types
719                                .drain(oa + la + ra..)
720                                .collect_vec();
721                            // Remember if `on` had any subqueries.
722                            let on_has_subqueries = !on_subquery_types.is_empty();
723
724                            // Attempt an efficient equijoin implementation, in which outer joins are
725                            // more efficiently rendered than in general. This can return `None` if
726                            // such a plan is not possible, for example if `on` does not describe an
727                            // equijoin between columns of `left` and `right`.
728                            if kind != JoinKind::Inner {
729                                if let Some(joined) = attempt_outer_equijoin(
730                                    get_left.clone(),
731                                    get_right.clone(),
732                                    on.clone(),
733                                    on_subquery_types,
734                                    kind.clone(),
735                                    oa,
736                                    id_gen,
737                                    context,
738                                )? {
739                                    if let Some(metrics) = context.metrics {
740                                        metrics.inc_outer_join_lowering("equi");
741                                    }
742                                    return Ok(joined);
743                                }
744                            }
745
746                            // Otherwise, perform a more general join.
747                            if let Some(metrics) = context.metrics {
748                                metrics.inc_outer_join_lowering("general");
749                            }
750                            let mut join = product.filter(vec![on]);
751                            if on_has_subqueries {
752                                // This means that `on.applied_to(...)` appended
753                                // some columns to handle subqueries, and now we
754                                // need to get rid of them.
755                                join = join.project((0..oa + la + ra).collect());
756                            }
757                            join.let_in(id_gen, |id_gen, get_join| {
758                                let mut result = get_join.clone();
759                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
760                                    kind
761                                {
762                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
763                                        id_gen,
764                                        get_join.clone(),
765                                        rt.into_iter()
766                                            .map(|typ| (Datum::Null, typ.scalar_type))
767                                            .collect(),
768                                    )?;
769                                    result = result.union(left_outer);
770                                }
771                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
772                                    let right_outer = get_right
773                                        .clone()
774                                        .anti_lookup::<PlanError>(
775                                            id_gen,
776                                            get_join
777                                                // need to swap left and right to make the anti_lookup work
778                                                .project(
779                                                    (0..oa)
780                                                        .chain((oa + la)..(oa + la + ra))
781                                                        .chain((oa)..(oa + la))
782                                                        .collect(),
783                                                ),
784                                            lt.into_iter()
785                                                .map(|typ| (Datum::Null, typ.scalar_type))
786                                                .collect(),
787                                        )?
788                                        // swap left and right back again
789                                        .project(
790                                            (0..oa)
791                                                .chain((oa + ra)..(oa + ra + la))
792                                                .chain((oa)..(oa + ra))
793                                                .collect(),
794                                        );
795                                    result = result.union(right_outer);
796                                }
797                                Ok::<MirRelationExpr, PlanError>(result)
798                            })
799                        })
800                    })?
801                }
802                Union { base, inputs } => {
803                    // Union is uncomplicated.
804                    SR::Union {
805                        base: Box::new(base.applied_to(
806                            id_gen,
807                            get_outer.clone(),
808                            col_map,
809                            cte_map,
810                            context,
811                        )?),
812                        inputs: inputs
813                            .into_iter()
814                            .map(|input| {
815                                input.applied_to(
816                                    id_gen,
817                                    get_outer.clone(),
818                                    col_map,
819                                    cte_map,
820                                    context,
821                                )
822                            })
823                            .collect::<Result<Vec<_>, _>>()?,
824                    }
825                }
826                Reduce {
827                    input,
828                    group_key,
829                    aggregates,
830                    expected_group_size,
831                } => {
832                    // Reduce may contain expressions with correlated subqueries.
833                    // In addition, here an empty reduction key signifies that we need to supply default values
834                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
835                    let mut input =
836                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
837                    let applied_group_key = (0..get_outer.arity())
838                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
839                        .collect();
840                    let applied_aggregates = aggregates
841                        .into_iter()
842                        .map(|aggregate| {
843                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
844                        })
845                        .collect::<Result<Vec<_>, _>>()?;
846                    let input_type = input.typ();
847                    let default = applied_aggregates
848                        .iter()
849                        .map(|agg| {
850                            (
851                                agg.func.default(),
852                                agg.typ(&input_type.column_types).scalar_type,
853                            )
854                        })
855                        .collect();
856                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
857                    let mut reduced =
858                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);
859
860                    // Introduce default values in the case the group key is empty.
861                    if group_key.is_empty() {
862                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
863                    }
864                    reduced
865                }
866                Distinct { input } => {
867                    // Distinct is uncomplicated.
868                    input
869                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
870                        .distinct()
871                }
872                TopK {
873                    input,
874                    group_key,
875                    order_key,
876                    limit,
877                    offset,
878                    expected_group_size,
879                } => {
880                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
881                    let mut input =
882                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
883                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
884                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
885                        .collect();
886                    let applied_order_key = order_key
887                        .iter()
888                        .map(|column_order| ColumnOrder {
889                            column: column_order.column + get_outer.arity(),
890                            desc: column_order.desc,
891                            nulls_last: column_order.nulls_last,
892                        })
893                        .collect();
894
895                    let old_arity = input.arity();
896
897                    // Lower `limit`, which may introduce new columns if it is a correlated subquery.
898                    let mut limit_mir = None;
899                    if let Some(limit) = limit {
900                        limit_mir = Some(
901                            limit
902                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
903                        );
904                    }
905
906                    let new_arity = input.arity();
907                    // Extend the key to contain any new columns.
908                    applied_group_key.extend(old_arity..new_arity);
909
910                    let offset = offset
911                        .try_into_literal_int64()
912                        .expect("Should be a Literal by this time")
913                        .try_into()
914                        .expect("Should have checked non-negativity of OFFSET clause already");
915                    let mut result = input.top_k(
916                        applied_group_key,
917                        applied_order_key,
918                        limit_mir,
919                        offset,
920                        expected_group_size,
921                    );
922
923                    // If new columns were added for `limit` we must remove them.
924                    if old_arity != new_arity {
925                        result = result.project((0..old_arity).collect());
926                    }
927
928                    result
929                }
930                Negate { input } => {
931                    // Negate is uncomplicated.
932                    input
933                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
934                        .negate()
935                }
936                Threshold { input } => {
937                    // Threshold is uncomplicated.
938                    input
939                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
940                        .threshold()
941                }
942            })
943        })
944    }
945}
946
947impl HirScalarExpr {
948    /// Rewrite `self` into a `MirScalarExpr` which can be applied to the modified `inner`.
949    ///
950    /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
951    /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
952    /// The most complicated logic is for the scalar expressions that involve subqueries, each of which is
953    /// documented in more detail closer to their logic.
954    ///
955    /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
956    /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
957    /// each of these references can be found.
958    fn applied_to(
959        self,
960        id_gen: &mut mz_ore::id_gen::IdGen,
961        col_map: &ColumnMap,
962        cte_map: &mut CteMap,
963        inner: &mut MirRelationExpr,
964        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
965        context: &Context,
966    ) -> Result<MirScalarExpr, PlanError> {
967        maybe_grow(|| {
968            use MirScalarExpr as SS;
969
970            use HirScalarExpr::*;
971
972            if let Some(subquery_map) = subquery_map {
973                if let Some(col) = subquery_map.get(&self) {
974                    return Ok(SS::column(*col));
975                }
976            }
977
978            Ok::<MirScalarExpr, PlanError>(match self {
979                Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
980                Literal(row, typ, _name) => SS::Literal(Ok(row), typ),
981                Parameter(_, _name) => {
982                    panic!("cannot decorrelate expression with unbound parameters")
983                }
984                CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
985                CallUnary {
986                    func: func::UnaryFunc::CastVarCharToString(_),
987                    expr,
988                    name: _,
989                } if context.config.enable_cast_elimination => {
990                    expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?
991                }
992                CallUnary {
993                    func,
994                    expr,
995                    name: _,
996                } => SS::CallUnary {
997                    func,
998                    expr: Box::new(expr.applied_to(
999                        id_gen,
1000                        col_map,
1001                        cte_map,
1002                        inner,
1003                        subquery_map,
1004                        context,
1005                    )?),
1006                },
1007                CallBinary {
1008                    func,
1009                    expr1,
1010                    expr2,
1011                    name: _,
1012                } => SS::CallBinary {
1013                    func,
1014                    expr1: Box::new(expr1.applied_to(
1015                        id_gen,
1016                        col_map,
1017                        cte_map,
1018                        inner,
1019                        subquery_map,
1020                        context,
1021                    )?),
1022                    expr2: Box::new(expr2.applied_to(
1023                        id_gen,
1024                        col_map,
1025                        cte_map,
1026                        inner,
1027                        subquery_map,
1028                        context,
1029                    )?),
1030                },
1031                CallVariadic {
1032                    func,
1033                    exprs,
1034                    name: _,
1035                } => SS::CallVariadic {
1036                    func,
1037                    exprs: exprs
1038                        .into_iter()
1039                        .map(|expr| {
1040                            expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
1041                        })
1042                        .collect::<Result<Vec<_>, _>>()?,
1043                },
1044                If {
1045                    cond,
1046                    then,
1047                    els,
1048                    name,
1049                } => {
1050                    // The `If` case is complicated by the fact that we do not want to
1051                    // apply the `then` or `else` logic to tuples that respectively do
1052                    // not or do pass the `cond` test. Our strategy is to independently
1053                    // decorrelate the `then` and `else` logic, and apply each to tuples
1054                    // that respectively pass and do not pass the `cond` logic (which is
1055                    // executed, and so decorrelated, for all tuples).
1056                    //
1057                    // Informally, we turn the `if` statement into:
1058                    //
1059                    //   let then_case = inner.filter(cond).map(then);
1060                    //   let else_case = inner.filter(!cond).map(else);
1061                    //   return then_case.concat(else_case);
1062                    //
1063                    // We only require this if either expression would result in any
1064                    // computation beyond the expr itself, which we will interpret as
1065                    // "introduces additional columns". In the absence of correlation,
1066                    // we should just retain a `ScalarExpr::If` expression; the inverse
1067                    // transformation as above is complicated to recover after the fact,
1068                    // and we would benefit from not introducing the complexity.
1069
1070                    let inner_arity = inner.arity();
1071                    let cond_expr =
1072                        cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1073
1074                    // Defensive copies, in case we mangle these in decorrelation.
1075                    let inner_clone = inner.clone();
1076                    let then_clone = then.clone();
1077                    let else_clone = els.clone();
1078
1079                    let cond_arity = inner.arity();
1080                    let then_expr =
1081                        then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1082                    let else_expr =
1083                        els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1084
1085                    if cond_arity == inner.arity() {
1086                        // If no additional columns were added, we simply return the
1087                        // `If` variant with the updated expressions.
1088                        SS::If {
1089                            cond: Box::new(cond_expr),
1090                            then: Box::new(then_expr),
1091                            els: Box::new(else_expr),
1092                        }
1093                    } else {
1094                        // If columns were added, we need a more careful approach, as
1095                        // described above. First, we need to de-correlate each of
1096                        // the two expressions independently, and apply their cases
1097                        // as `MirRelationExpr::Map` operations.
1098
1099                        *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
1100                            // Restrict to records satisfying `cond_expr` and apply `then` as a map.
1101                            let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
1102                            let then_expr = then_clone.applied_to(
1103                                id_gen,
1104                                col_map,
1105                                cte_map,
1106                                &mut then_inner,
1107                                subquery_map,
1108                                context,
1109                            )?;
1110                            let then_arity = then_inner.arity();
1111                            then_inner = then_inner
1112                                .map_one(then_expr)
1113                                .project((0..inner_arity).chain(Some(then_arity)).collect());
1114
1115                            // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
1116                            let mut else_inner = get_inner.filter(vec![SS::CallVariadic {
1117                                func: mz_expr::VariadicFunc::Or,
1118                                exprs: vec![
1119                                    cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
1120                                    cond_expr.clone().call_is_null(),
1121                                ],
1122                            }]);
1123                            let else_expr = else_clone.applied_to(
1124                                id_gen,
1125                                col_map,
1126                                cte_map,
1127                                &mut else_inner,
1128                                subquery_map,
1129                                context,
1130                            )?;
1131                            let else_arity = else_inner.arity();
1132                            else_inner = else_inner
1133                                .map_one(else_expr)
1134                                .project((0..inner_arity).chain(Some(else_arity)).collect());
1135
1136                            // concatenate the two results.
1137                            Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
1138                        })?;
1139
1140                        SS::Column(inner_arity, name)
1141                    }
1142                }
1143
1144                // Subqueries!
1145                // These are surprisingly subtle. Things to be careful of:
1146
1147                // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
1148                // * change the row counts of the outer query
1149                // * accidentally compute its own value using the row counts of the outer query
1150                // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
1151                // query and then join the answers back on to the original rows of the outer query.
1152
1153                // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
1154                // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
1155                Exists(expr, name) => {
1156                    let apply_requires_distinct_outer = true;
1157                    *inner = apply_existential_subquery(
1158                        id_gen,
1159                        inner.take_dangerous(),
1160                        col_map,
1161                        cte_map,
1162                        *expr,
1163                        apply_requires_distinct_outer,
1164                        context,
1165                    )?;
1166                    SS::Column(inner.arity() - 1, name)
1167                }
1168
1169                Select(expr, name) => {
1170                    let apply_requires_distinct_outer = true;
1171                    *inner = apply_scalar_subquery(
1172                        id_gen,
1173                        inner.take_dangerous(),
1174                        col_map,
1175                        cte_map,
1176                        *expr,
1177                        apply_requires_distinct_outer,
1178                        context,
1179                    )?;
1180                    SS::Column(inner.arity() - 1, name)
1181                }
1182                Windowing(expr, _name) => {
1183                    let partition_by = expr.partition_by;
1184                    let order_by = expr.order_by;
1185
1186                    // argument lowering for scalar window functions
1187                    // (We need to specify the & _ in the arguments because of this problem:
1188                    // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
1189                    let scalar_lower_args =
1190                        |_id_gen: &mut _,
1191                         _col_map: &_,
1192                         _cte_map: &mut _,
1193                         _get_inner: &mut _,
1194                         _subquery_map: &Option<&_>,
1195                         order_by_mir: Vec<MirScalarExpr>,
1196                         original_row_record,
1197                         original_row_record_type: SqlScalarType| {
1198                            let agg_input = MirScalarExpr::CallVariadic {
1199                                func: mz_expr::VariadicFunc::ListCreate {
1200                                    elem_type: original_row_record_type.clone(),
1201                                },
1202                                exprs: vec![original_row_record],
1203                            };
1204                            let mut agg_input = vec![agg_input];
1205                            agg_input.extend(order_by_mir.clone());
1206                            let agg_input = MirScalarExpr::CallVariadic {
1207                                func: mz_expr::VariadicFunc::RecordCreate {
1208                                    field_names: (0..agg_input.len())
1209                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1210                                        .collect_vec(),
1211                                },
1212                                exprs: agg_input,
1213                            };
1214                            let list_type = SqlScalarType::List {
1215                                element_type: Box::new(original_row_record_type),
1216                                custom_id: None,
1217                            };
1218                            let agg_input_type = SqlScalarType::Record {
1219                                fields: std::iter::once(&list_type)
1220                                    .map(|t| {
1221                                        (
1222                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1223                                            t.clone().nullable(false),
1224                                        )
1225                                    })
1226                                    .collect(),
1227                                custom_id: None,
1228                            }
1229                            .nullable(false);
1230
1231                            Ok((agg_input, agg_input_type))
1232                        };
1233
1234                    // argument lowering for value window functions and aggregate window functions
1235                    let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
1236                        |id_gen: &mut _,
1237                         col_map: &_,
1238                         cte_map: &mut _,
1239                         get_inner: &mut _,
1240                         subquery_map: &Option<&_>,
1241                         order_by_mir: Vec<MirScalarExpr>,
1242                         original_row_record,
1243                         original_row_record_type| {
1244                            // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]
1245
1246                            // Compute the encoded args for all rows
1247                            let mir_encoded_args = hir_encoded_args.applied_to(
1248                                id_gen,
1249                                col_map,
1250                                cte_map,
1251                                get_inner,
1252                                subquery_map,
1253                                context,
1254                            )?;
1255                            let mir_encoded_args_type = mir_encoded_args
1256                                .typ(&get_inner.typ().column_types)
1257                                .scalar_type;
1258
1259                            // Build a new record that has two fields:
1260                            // 1. the original row in a record
1261                            // 2. the encoded args (which can be either a single value, or a record
1262                            //    if the window function has multiple arguments, such as `lag`)
1263                            let fn_input_record_fields: Box<[_]> =
1264                                [original_row_record_type, mir_encoded_args_type]
1265                                    .iter()
1266                                    .map(|t| {
1267                                        (
1268                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1269                                            t.clone().nullable(false),
1270                                        )
1271                                    })
1272                                    .collect();
1273                            let fn_input_record = MirScalarExpr::CallVariadic {
1274                                func: mz_expr::VariadicFunc::RecordCreate {
1275                                    field_names: fn_input_record_fields
1276                                        .iter()
1277                                        .map(|(n, _)| n.clone())
1278                                        .collect_vec(),
1279                                },
1280                                exprs: vec![original_row_record, mir_encoded_args],
1281                            };
1282                            let fn_input_record_type = SqlScalarType::Record {
1283                                fields: fn_input_record_fields,
1284                                custom_id: None,
1285                            }
1286                            .nullable(false);
1287
1288                            // Build a new record with the record above + the ORDER BY exprs
1289                            // This follows the standard encoding of ORDER BY exprs used by aggregate functions
1290                            let mut agg_input = vec![fn_input_record];
1291                            agg_input.extend(order_by_mir.clone());
1292                            let agg_input = MirScalarExpr::CallVariadic {
1293                                func: mz_expr::VariadicFunc::RecordCreate {
1294                                    field_names: (0..agg_input.len())
1295                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1296                                        .collect_vec(),
1297                                },
1298                                exprs: agg_input,
1299                            };
1300
1301                            let agg_input_type = SqlScalarType::Record {
1302                                fields: [(
1303                                    ColumnName::from(UNKNOWN_COLUMN_NAME),
1304                                    fn_input_record_type.nullable(false),
1305                                )]
1306                                .into(),
1307                                custom_id: None,
1308                            }
1309                            .nullable(false);
1310
1311                            Ok((agg_input, agg_input_type))
1312                        }
1313                    };
1314
1315                    match expr.func {
1316                        WindowExprType::Scalar(scalar_window_expr) => {
1317                            let mir_aggr_func = scalar_window_expr.into_expr();
1318                            Self::window_func_applied_to(
1319                                id_gen,
1320                                col_map,
1321                                cte_map,
1322                                inner,
1323                                subquery_map,
1324                                partition_by,
1325                                order_by,
1326                                mir_aggr_func,
1327                                scalar_lower_args,
1328                                context,
1329                            )?
1330                        }
1331                        WindowExprType::Value(value_window_expr) => {
1332                            let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();
1333
1334                            Self::window_func_applied_to(
1335                                id_gen,
1336                                col_map,
1337                                cte_map,
1338                                inner,
1339                                subquery_map,
1340                                partition_by,
1341                                order_by,
1342                                mir_aggr_func,
1343                                value_or_aggr_lower_args(hir_encoded_args),
1344                                context,
1345                            )?
1346                        }
1347                        WindowExprType::Aggregate(aggr_window_expr) => {
1348                            let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();
1349
1350                            Self::window_func_applied_to(
1351                                id_gen,
1352                                col_map,
1353                                cte_map,
1354                                inner,
1355                                subquery_map,
1356                                partition_by,
1357                                order_by,
1358                                mir_aggr_func,
1359                                value_or_aggr_lower_args(hir_encoded_args),
1360                                context,
1361                            )?
1362                        }
1363                    }
1364                }
1365            })
1366        })
1367    }
1368
1369    fn window_func_applied_to<F>(
1370        id_gen: &mut mz_ore::id_gen::IdGen,
1371        col_map: &ColumnMap,
1372        cte_map: &mut CteMap,
1373        inner: &mut MirRelationExpr,
1374        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
1375        partition_by: Vec<HirScalarExpr>,
1376        order_by: Vec<HirScalarExpr>,
1377        mir_aggr_func: AggregateFunc,
1378        lower_args: F,
1379        context: &Context,
1380    ) -> Result<MirScalarExpr, PlanError>
1381    where
1382        F: FnOnce(
1383            &mut mz_ore::id_gen::IdGen,
1384            &ColumnMap,
1385            &mut CteMap,
1386            &mut MirRelationExpr,
1387            &Option<&BTreeMap<HirScalarExpr, usize>>,
1388            Vec<MirScalarExpr>,
1389            MirScalarExpr,
1390            SqlScalarType,
1391        ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
1392    {
1393        // Example MIRs for a window function (specifically, a window aggregation):
1394        //
1395        // CREATE TABLE t7(x INT, y INT);
1396        //
1397        // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1398        //
1399        // Decorrelated Plan
1400        // Project (#3)
1401        //   Map (#2)
1402        //     Project (#3..=#5)
1403        //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
1404        //         FlatMap unnest_list(#1)
1405        //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1406        //             Map ((#0 + #1))
1407        //               CrossJoin
1408        //                 Constant
1409        //                   - ()
1410        //                 Get materialize.public.t7
1411        //
1412        // The same query after optimizations:
1413        //
1414        // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1415        //
1416        // Optimized Plan
1417        // Explained Query:
1418        //   Project (#2)
1419        //     Map (record_get[0](#1))
1420        //       FlatMap unnest_list(#0)
1421        //         Project (#1)
1422        //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1423        //             ReadStorage materialize.public.t7
1424        //
1425        // The `row(row(row(...), ...), ...)` stuff means the following:
1426        // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
1427        //   - The <arguments to window function> can be either a single value or itself a
1428        //     `row` if there are multiple arguments.
1429        //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
1430        //     ORDER BY columns.
1431        //   - The <original row> currently always captures the entire original row. This should
1432        //     improve when we make `ProjectionPushdown` smarter, see
1433        //     https://github.com/MaterializeInc/database-issues/issues/5090
1434        //
1435        // TODO:
1436        // We should probably introduce some dedicated Datum constructor functions instead of `row`
1437        // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
1438        // might even introduce dedicated Datum enum variants, so that the rendering code also
1439        // becomes more readable (and possibly slightly more performant).
1440
1441        *inner = inner
1442            .take_dangerous()
1443            .let_in(id_gen, |id_gen, mut get_inner| {
1444                let order_by_mir = order_by
1445                    .into_iter()
1446                    .map(|o| {
1447                        o.applied_to(
1448                            id_gen,
1449                            col_map,
1450                            cte_map,
1451                            &mut get_inner,
1452                            subquery_map,
1453                            context,
1454                        )
1455                    })
1456                    .collect::<Result<Vec<_>, _>>()?;
1457
1458                // Record input arity here so that any group_keys that need to mutate get_inner
1459                // don't add those columns to the aggregate input.
1460                let input_type = get_inner.typ();
1461                let input_arity = input_type.arity();
1462                // The reduction that computes the window function must be keyed on the columns
1463                // from the outer context, plus the expressions in the partition key. The current
1464                // subquery will be 'executed' for every distinct row from the outer context so
1465                // by putting the outer columns in the grouping key we isolate each re-execution.
1466                let mut group_key = col_map
1467                    .inner
1468                    .iter()
1469                    .map(|(_, outer_col)| *outer_col)
1470                    .sorted()
1471                    .collect_vec();
1472                for p in partition_by {
1473                    let key = p.applied_to(
1474                        id_gen,
1475                        col_map,
1476                        cte_map,
1477                        &mut get_inner,
1478                        subquery_map,
1479                        context,
1480                    )?;
1481                    if let MirScalarExpr::Column(c, _name) = key {
1482                        group_key.push(c);
1483                    } else {
1484                        get_inner = get_inner.map_one(key);
1485                        group_key.push(get_inner.arity() - 1);
1486                    }
1487                }
1488
1489                get_inner.let_in(id_gen, |id_gen, mut get_inner| {
1490                    // Original columns of the relation
1491                    let fields: Box<_> = input_type
1492                        .column_types
1493                        .iter()
1494                        .take(input_arity)
1495                        .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
1496                        .collect();
1497
1498                    // Original row made into a record
1499                    let original_row_record = MirScalarExpr::CallVariadic {
1500                        func: mz_expr::VariadicFunc::RecordCreate {
1501                            field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
1502                        },
1503                        exprs: (0..input_arity).map(MirScalarExpr::column).collect_vec(),
1504                    };
1505                    let original_row_record_type = SqlScalarType::Record {
1506                        fields,
1507                        custom_id: None,
1508                    };
1509
1510                    let (agg_input, agg_input_type) = lower_args(
1511                        id_gen,
1512                        col_map,
1513                        cte_map,
1514                        &mut get_inner,
1515                        subquery_map,
1516                        order_by_mir,
1517                        original_row_record,
1518                        original_row_record_type,
1519                    )?;
1520
1521                    let aggregate = mz_expr::AggregateExpr {
1522                        func: mir_aggr_func,
1523                        expr: agg_input,
1524                        distinct: false,
1525                    };
1526
1527                    // Actually call reduce with the window function
1528                    // The output of the aggregation function should be a list of tuples that has
1529                    // the result in the first position, and the original row in the second position
1530                    let mut reduce = get_inner
1531                        .reduce(group_key.clone(), vec![aggregate.clone()], None)
1532                        .flat_map(
1533                            mz_expr::TableFunc::UnnestList {
1534                                el_typ: aggregate
1535                                    .func
1536                                    .output_type(agg_input_type)
1537                                    .scalar_type
1538                                    .unwrap_list_element_type()
1539                                    .clone(),
1540                            },
1541                            vec![MirScalarExpr::column(group_key.len())],
1542                        );
1543                    let record_col = reduce.arity() - 1;
1544
1545                    // Unpack the record output by the window function
1546                    for c in 0..input_arity {
1547                        reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1548                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
1549                            expr: Box::new(MirScalarExpr::CallUnary {
1550                                func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
1551                                expr: Box::new(MirScalarExpr::column(record_col)),
1552                            }),
1553                        });
1554                    }
1555
1556                    // Append the column with the result of the window function.
1557                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1558                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
1559                        expr: Box::new(MirScalarExpr::column(record_col)),
1560                    });
1561
1562                    let agg_col = record_col + 1 + input_arity;
1563                    Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
1564                })
1565            })?;
1566        Ok(MirScalarExpr::column(inner.arity() - 1))
1567    }
1568
1569    /// Applies the subqueries in the given list of scalar expressions to every distinct
1570    /// value of the given relation and returns a join of the given relation with all
1571    /// the subqueries found, and the mapping of scalar expressions with columns projected
1572    /// by the returned join that will hold their results.
1573    fn lower_subqueries(
1574        exprs: &[Self],
1575        id_gen: &mut mz_ore::id_gen::IdGen,
1576        col_map: &ColumnMap,
1577        cte_map: &mut CteMap,
1578        inner: MirRelationExpr,
1579        context: &Context,
1580    ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1581        let mut subquery_map = BTreeMap::new();
1582        let output = inner.let_in(id_gen, |id_gen, get_inner| {
1583            let mut subqueries = Vec::new();
1584            let distinct_inner = get_inner.clone().distinct();
1585            for expr in exprs.iter() {
1586                expr.visit_pre_post(
1587                    &mut |e| match e {
1588                        // For simplicity, subqueries within a conditional statement will be
1589                        // lowered when lowering the conditional expression.
1590                        HirScalarExpr::If { .. } => Some(vec![]),
1591                        _ => None,
1592                    },
1593                    &mut |e| match e {
1594                        HirScalarExpr::Select(expr, _name) => {
1595                            let apply_requires_distinct_outer = false;
1596                            let subquery = apply_scalar_subquery(
1597                                id_gen,
1598                                distinct_inner.clone(),
1599                                col_map,
1600                                cte_map,
1601                                (**expr).clone(),
1602                                apply_requires_distinct_outer,
1603                                context,
1604                            )
1605                            .unwrap();
1606
1607                            subqueries.push((e.clone(), subquery));
1608                        }
1609                        HirScalarExpr::Exists(expr, _name) => {
1610                            let apply_requires_distinct_outer = false;
1611                            let subquery = apply_existential_subquery(
1612                                id_gen,
1613                                distinct_inner.clone(),
1614                                col_map,
1615                                cte_map,
1616                                (**expr).clone(),
1617                                apply_requires_distinct_outer,
1618                                context,
1619                            )
1620                            .unwrap();
1621                            subqueries.push((e.clone(), subquery));
1622                        }
1623                        _ => {}
1624                    },
1625                )?;
1626            }
1627
1628            if subqueries.is_empty() {
1629                Ok::<MirRelationExpr, PlanError>(get_inner)
1630            } else {
1631                let inner_arity = get_inner.arity();
1632                let mut total_arity = inner_arity;
1633                let mut join_inputs = vec![get_inner];
1634                let mut join_input_arities = vec![inner_arity];
1635                for (expr, subquery) in subqueries.into_iter() {
1636                    // Avoid lowering duplicated subqueries
1637                    if !subquery_map.contains_key(&expr) {
1638                        let subquery_arity = subquery.arity();
1639                        assert_eq!(subquery_arity, inner_arity + 1);
1640                        join_inputs.push(subquery);
1641                        join_input_arities.push(subquery_arity);
1642                        total_arity += subquery_arity;
1643
1644                        // Column with the value of the subquery
1645                        subquery_map.insert(expr, total_arity - 1);
1646                    }
1647                }
1648                // Each subquery projects all the columns of the outer context (distinct_inner)
1649                // plus 1 column, containing the result of the subquery. Those columns must be
1650                // joined with the outer/main relation (get_inner).
1651                let input_mapper =
1652                    mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1653                let equivalences = (0..inner_arity)
1654                    .map(|col| {
1655                        join_inputs
1656                            .iter()
1657                            .enumerate()
1658                            .map(|(input, _)| {
1659                                MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1660                            })
1661                            .collect_vec()
1662                    })
1663                    .collect_vec();
1664                Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1665            }
1666        })?;
1667        Ok((output, subquery_map))
1668    }
1669
1670    /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1671    ///
1672    /// Returns an _internal_ error if the expression contains
1673    /// - a subquery
1674    /// - a column reference to an outer level
1675    /// - a parameter
1676    /// - a window function call
1677    ///
1678    /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1679    ///
1680    /// Set `enable_cast_elimination` to remove casts that are noops in MIR.
1681    pub fn lower_uncorrelated<C: Into<Config>>(
1682        self,
1683        config: C,
1684    ) -> Result<MirScalarExpr, PlanError> {
1685        let config = config.into();
1686
1687        use MirScalarExpr as SS;
1688
1689        use HirScalarExpr::*;
1690
1691        Ok(match self {
1692            Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1693            Literal(datum, typ, _name) => SS::Literal(Ok(datum), typ),
1694            CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1695            CallUnary {
1696                func: func::UnaryFunc::CastVarCharToString(_),
1697                expr,
1698                name: _,
1699            } if config.enable_cast_elimination => expr.lower_uncorrelated(config)?,
1700            CallUnary {
1701                func,
1702                expr,
1703                name: _,
1704            } => SS::CallUnary {
1705                func,
1706                expr: Box::new(expr.lower_uncorrelated(config)?),
1707            },
1708            CallBinary {
1709                func,
1710                expr1,
1711                expr2,
1712                name: _,
1713            } => SS::CallBinary {
1714                func,
1715                expr1: Box::new(expr1.lower_uncorrelated(config)?),
1716                expr2: Box::new(expr2.lower_uncorrelated(config)?),
1717            },
1718            CallVariadic {
1719                func,
1720                exprs,
1721                name: _,
1722            } => SS::CallVariadic {
1723                func,
1724                exprs: exprs
1725                    .into_iter()
1726                    .map(|expr| expr.lower_uncorrelated(config))
1727                    .collect::<Result<_, _>>()?,
1728            },
1729            If {
1730                cond,
1731                then,
1732                els,
1733                name: _,
1734            } => SS::If {
1735                cond: Box::new(cond.lower_uncorrelated(config)?),
1736                then: Box::new(then.lower_uncorrelated(config)?),
1737                els: Box::new(els.lower_uncorrelated(config)?),
1738            },
1739            Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1740                sql_bail!(
1741                    "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1742                    self
1743                );
1744            }
1745        })
1746    }
1747}
1748
1749/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
1750/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
1751/// will, in effect, be executed once for every distinct row in `outer`, and the
1752/// results will be joined with `outer`. Note that columns in `outer` that are
1753/// not depended upon by `inner` are thrown away before the distinct, so that we
1754/// don't perform needless computation of `inner`.
1755///
1756/// `branch` will inspect the contents of `inner` to determine whether `inner`
1757/// is not multiplicity sensitive (roughly, contains only maps, filters,
1758/// projections, and calls to table functions). If it is not multiplicity
1759/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
1760/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
1761/// operations that were not present in `inner`, then set
1762/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
1763/// that distinctifies `outer`.
1764///
1765/// The caller must supply the `apply` function that applies the rewritten
1766/// `inner` to `outer`.
1767fn branch<F>(
1768    id_gen: &mut mz_ore::id_gen::IdGen,
1769    outer: MirRelationExpr,
1770    col_map: &ColumnMap,
1771    cte_map: &mut CteMap,
1772    inner: HirRelationExpr,
1773    apply_requires_distinct_outer: bool,
1774    context: &Context,
1775    apply: F,
1776) -> Result<MirRelationExpr, PlanError>
1777where
1778    F: FnOnce(
1779        &mut mz_ore::id_gen::IdGen,
1780        HirRelationExpr,
1781        MirRelationExpr,
1782        &ColumnMap,
1783        &mut CteMap,
1784        &Context,
1785    ) -> Result<MirRelationExpr, PlanError>,
1786{
1787    // TODO: It would be nice to have a version of this code w/o optimizations,
1788    // at the least for purposes of understanding. It was difficult for one reader
1789    // to understand the required properties of `outer` and `col_map`.
1790
1791    // If the inner expression is sufficiently simple, it is safe to apply it
1792    // *directly* to outer, rather than applying it to the distinctified key
1793    // (see below).
1794    //
1795    // As an example, consider the following two queries:
1796    //
1797    //     CREATE TABLE t (a int, b int);
1798    //     SELECT a, series FROM t, generate_series(1, t.b) series;
1799    //
1800    // The "simple" path for the `SELECT` yields
1801    //
1802    //     %0 =
1803    //     | Get t
1804    //     | FlatMap generate_series(1, #1)
1805    //
1806    // while the non-simple path yields:
1807    //
1808    //    %0 =
1809    //    | Get t
1810    //
1811    //    %1 =
1812    //    | Get t
1813    //    | Distinct group=(#1)
1814    //    | FlatMap generate_series(1, #0)
1815    //
1816    //    %2 =
1817    //    | LeftJoin %0 %1 (= #1 #2)
1818    //
1819    // There is a tradeoff here: the simple plan is stateless, but the non-
1820    // simple plan may do (much) less computation if there are only a few
1821    // distinct values of `t.b`.
1822    //
1823    // We apply a very simple heuristic here and take the simple path if `inner`
1824    // contains only maps, filters, projections, and calls to table functions.
1825    // The intuition is that straightforward usage of table functions should
1826    // take the simple path, while everything else should not. (In theory we
1827    // think this transformation is valid as long as `inner` does not contain a
1828    // Reduce, Distinct, or TopK node, but it is not always an optimization in
1829    // the general case.)
1830    //
1831    // TODO(benesch): this should all be handled by a proper optimizer, but
1832    // detecting the moment of decorrelation in the optimizer right now is too
1833    // hard.
1834    let mut is_simple = true;
1835    #[allow(deprecated)]
1836    inner.visit(0, &mut |expr, _| match expr {
1837        HirRelationExpr::Constant { .. }
1838        | HirRelationExpr::Project { .. }
1839        | HirRelationExpr::Map { .. }
1840        | HirRelationExpr::Filter { .. }
1841        | HirRelationExpr::CallTable { .. } => (),
1842        _ => is_simple = false,
1843    });
1844    if is_simple && !apply_requires_distinct_outer {
1845        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
1846        return outer.let_in(id_gen, |id_gen, get_outer| {
1847            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
1848        });
1849    }
1850
1851    // The key consists of the columns from the outer expression upon which the
1852    // inner relation depends. We discover these dependencies by walking the
1853    // inner relation expression and looking for column references whose level
1854    // escapes inner.
1855    //
1856    // At the end of this process, `key` contains the decorrelated position of
1857    // each outer column, according to the passed-in `col_map`, and
1858    // `new_col_map` maps each outer column to its new ordinal position in key.
1859    let mut outer_cols = BTreeSet::new();
1860    #[allow(deprecated)]
1861    inner.visit_columns(0, &mut |depth, col| {
1862        // Test if the column reference escapes the subquery.
1863        if col.level > depth {
1864            outer_cols.insert(ColumnRef {
1865                level: col.level - depth,
1866                column: col.column,
1867            });
1868        }
1869    });
1870    // Collect all the outer columns referenced by any CTE referenced by
1871    // the inner relation.
1872    #[allow(deprecated)]
1873    inner.visit(0, &mut |e, _| match e {
1874        HirRelationExpr::Get {
1875            id: mz_expr::Id::Local(id),
1876            ..
1877        } => {
1878            if let Some(cte_desc) = cte_map.get(id) {
1879                let cte_outer_arity = cte_desc.outer_relation.arity();
1880                outer_cols.extend(
1881                    col_map
1882                        .inner
1883                        .iter()
1884                        .filter(|(_, position)| **position < cte_outer_arity)
1885                        .map(|(c, _)| {
1886                            // `col_map` maps column references to column positions in
1887                            // `outer`'s projection.
1888                            // `outer_cols` is meant to contain the external column
1889                            // references in `inner`.
1890                            // Since `inner` defines a new scope, any column reference
1891                            // in `col_map` is one level deeper when seen from within
1892                            // `inner`, hence the +1.
1893                            ColumnRef {
1894                                level: c.level + 1,
1895                                column: c.column,
1896                            }
1897                        }),
1898                );
1899            }
1900        }
1901        HirRelationExpr::Let { id, .. } => {
1902            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
1903            // we would need to remove the old CTE with the same ID temporarily while
1904            // traversing the definition of the new CTE under the same ID.
1905            assert!(!cte_map.contains_key(id));
1906        }
1907        _ => {}
1908    });
1909    let mut new_col_map = BTreeMap::new();
1910    let mut key = vec![];
1911    for col in outer_cols {
1912        new_col_map.insert(col, key.len());
1913        key.push(col_map.get(&ColumnRef {
1914            // Note: `outer_cols` contains the external column references within `inner`.
1915            // We must compensate for `inner`'s scope when translating column references
1916            // as seen within `inner` to column references as seen from `outer`'s context,
1917            // hence the -1.
1918            level: col.level - 1,
1919            column: col.column,
1920        }));
1921    }
1922    let new_col_map = ColumnMap::new(new_col_map);
1923    outer.let_in(id_gen, |id_gen, get_outer| {
1924        let keyed_outer = if key.is_empty() {
1925            // Don't depend on outer at all if the branch is not correlated,
1926            // which yields vastly better query plans. Note that this is a bit
1927            // weird in that the branch will be computed even if outer has no
1928            // rows, whereas if it had been correlated it would not (and *could*
1929            // not) have been computed if outer had no rows, but the callers of
1930            // this function don't mind these somewhat-weird semantics.
1931            MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![]))
1932        } else {
1933            get_outer.clone().distinct_by(key.clone())
1934        };
1935        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
1936            let oa = get_outer.arity();
1937            let branch = apply(
1938                id_gen,
1939                inner,
1940                get_keyed_outer,
1941                &new_col_map,
1942                cte_map,
1943                context,
1944            )?;
1945            let ba = branch.arity();
1946            let joined = MirRelationExpr::join(
1947                vec![get_outer.clone(), branch],
1948                key.iter()
1949                    .enumerate()
1950                    .map(|(i, &k)| vec![(0, k), (1, i)])
1951                    .collect(),
1952            )
1953            // throw away the right-hand copy of the key we just joined on
1954            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
1955            Ok(joined)
1956        })
1957    })
1958}
1959
1960fn apply_scalar_subquery(
1961    id_gen: &mut mz_ore::id_gen::IdGen,
1962    outer: MirRelationExpr,
1963    col_map: &ColumnMap,
1964    cte_map: &mut CteMap,
1965    scalar_subquery: HirRelationExpr,
1966    apply_requires_distinct_outer: bool,
1967    context: &Context,
1968) -> Result<MirRelationExpr, PlanError> {
1969    branch(
1970        id_gen,
1971        outer,
1972        col_map,
1973        cte_map,
1974        scalar_subquery,
1975        apply_requires_distinct_outer,
1976        context,
1977        |id_gen, expr, get_inner, col_map, cte_map, context| {
1978            // compute for every row in get_inner
1979            let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1980            let col_type = select.typ().column_types.into_last();
1981
1982            let inner_arity = get_inner.arity();
1983            // We must determine a count for each `get_inner` prefix,
1984            // and report an error if that count exceeds one.
1985            let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1986                // Count for each `get_inner` prefix.
1987                let counts = get_select.clone().reduce(
1988                    (0..inner_arity).collect::<Vec<_>>(),
1989                    vec![mz_expr::AggregateExpr {
1990                        func: mz_expr::AggregateFunc::Count,
1991                        expr: MirScalarExpr::literal_true(),
1992                        distinct: false,
1993                    }],
1994                    None,
1995                );
1996
1997                let use_guard = context.config.enable_guard_subquery_tablefunc;
1998
1999                // Errors should result from counts > 1.
2000                let errors = if use_guard {
2001                    counts
2002                        .flat_map(
2003                            mz_expr::TableFunc::GuardSubquerySize {
2004                                column_type: col_type.clone().scalar_type,
2005                            },
2006                            vec![MirScalarExpr::column(inner_arity)],
2007                        )
2008                        .project(
2009                            (0..inner_arity)
2010                                .chain(Some(inner_arity + 1))
2011                                .collect::<Vec<_>>(),
2012                        )
2013                } else {
2014                    counts
2015                        .filter(vec![MirScalarExpr::column(inner_arity).call_binary(
2016                            MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
2017                            func::Gt,
2018                        )])
2019                        .project((0..inner_arity).collect::<Vec<_>>())
2020                        .map_one(MirScalarExpr::literal(
2021                            Err(mz_expr::EvalError::MultipleRowsFromSubquery),
2022                            col_type.clone().scalar_type,
2023                        ))
2024                };
2025                // Return `get_select` and any errors added in.
2026                Ok::<_, PlanError>(get_select.union(errors))
2027            })?;
2028            // append Null to anything that didn't return any rows
2029            let default = vec![(Datum::Null, col_type.scalar_type)];
2030            get_inner.lookup(id_gen, guarded, default)
2031        },
2032    )
2033}
2034
2035fn apply_existential_subquery(
2036    id_gen: &mut mz_ore::id_gen::IdGen,
2037    outer: MirRelationExpr,
2038    col_map: &ColumnMap,
2039    cte_map: &mut CteMap,
2040    subquery_expr: HirRelationExpr,
2041    apply_requires_distinct_outer: bool,
2042    context: &Context,
2043) -> Result<MirRelationExpr, PlanError> {
2044    branch(
2045        id_gen,
2046        outer,
2047        col_map,
2048        cte_map,
2049        subquery_expr,
2050        apply_requires_distinct_outer,
2051        context,
2052        |id_gen, expr, get_inner, col_map, cte_map, context| {
2053            let exists = expr
2054                // compute for every row in get_inner
2055                .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2056                // throw away actual values and just remember whether or not there were __any__ rows
2057                .distinct_by((0..get_inner.arity()).collect())
2058                // Append true to anything that returned any rows.
2059                .map(vec![MirScalarExpr::literal_true()]);
2060
2061            // append False to anything that didn't return any rows
2062            get_inner.lookup(id_gen, exists, vec![(Datum::False, SqlScalarType::Bool)])
2063        },
2064    )
2065}
2066
2067impl AggregateExpr {
2068    fn applied_to(
2069        self,
2070        id_gen: &mut mz_ore::id_gen::IdGen,
2071        col_map: &ColumnMap,
2072        cte_map: &mut CteMap,
2073        inner: &mut MirRelationExpr,
2074        context: &Context,
2075    ) -> Result<mz_expr::AggregateExpr, PlanError> {
2076        let AggregateExpr {
2077            func,
2078            expr,
2079            distinct,
2080        } = self;
2081
2082        Ok(mz_expr::AggregateExpr {
2083            func: func.into_expr(),
2084            expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2085            distinct,
2086        })
2087    }
2088}
2089
/// Attempts an efficient outer join, if `on` has equijoin structure.
///
/// Both `left` and `right` are decorrelated inputs.
///
/// The first `oa` columns correspond to an outer context: we should do the
/// outer join independently for each prefix. In the case that `on` contains
/// just some equality tests between columns of `left` and `right` and some
/// local predicates, we can employ a relatively simple plan.
///
/// The last `on_subquery_types.len()` columns correspond to results from
/// subqueries defined in the `on` clause - we treat those as theta-join
/// conditions that prohibit the use of the simple plan attempted here.
///
/// # Returns
///
/// * `Ok(None)` if the classified `on` predicates do not pass
///   `OnPredicates::is_equijoin`; nothing has been planned in that case.
/// * `Ok(Some(result))` otherwise, where `result` is the inner join of `left`
///   and `right` (on the outer columns, filtered by `on`), unioned with
///   - the null-padded unmatched rows of `left` when `kind` is
///     `LeftOuter` or `FullOuter`, and
///   - the null-padded unmatched rows of `right` when `kind` is
///     `RightOuter` or `FullOuter`.
fn attempt_outer_equijoin(
    left: MirRelationExpr,
    right: MirRelationExpr,
    on: MirScalarExpr,
    on_subquery_types: Vec<SqlColumnType>,
    kind: JoinKind,
    oa: usize,
    id_gen: &mut mz_ore::id_gen::IdGen,
    context: &Context,
) -> Result<Option<MirRelationExpr>, PlanError> {
    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
    // predicates that reference subqueries as long as these subqueries don't
    // reference `left` and `right` at the same time.
    //
    // TODO(database-issues#6828): This code can be improved as follows:
    //
    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
    // 2. Use the canonicalized `on` predicate in the non-equijoin based
    //    lowering strategy.
    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
    // 4. Pass the classified `OnPredicates` as a parameter.
    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
    //
    // Steps (1 + 2) require further investigation because we might change the
    // error semantics in case the `on` predicate contains a literal error.

    // Arities of the `left` and `right` inputs without the shared outer
    // prefix, and of the `on` subquery results.
    let l_type = left.typ();
    let r_type = right.typ();
    let la = l_type.column_types.len() - oa;
    let ra = r_type.column_types.len() - oa;
    let sa = on_subquery_types.len();

    // The output type contains [outer, left, right, sa] attributes.
    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
    output_type.extend(l_type.column_types);
    // Skip the outer prefix of `right`: it is already contributed by `left`.
    output_type.extend(r_type.column_types.into_iter().skip(oa));
    output_type.extend(on_subquery_types);

    // Generally healthy to do, but specifically `USING` conditions sometimes
    // put an `AND true` at the end of the `ON` condition.
    //
    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
    // However, in that case it's not clear that we won't see regressions if
    // `on` simplifies to a literal error.
    let mut on = vec![on];
    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);

    // Form the left and right types without the outer attributes.
    output_type.drain(0..oa);
    let lt = output_type.drain(0..la).collect_vec();
    let rt = output_type.drain(0..ra).collect_vec();
    // Only the `on` subquery types must be left at this point.
    assert!(output_type.len() == sa);

    // Classify the canonicalized predicates and bail out if the equijoin-based
    // strategy is not applicable.
    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
    if !on_predicates.is_equijoin(context) {
        return Ok(None);
    }

    // If we've gotten this far, we can do the clever thing.
    // We'll want to use left and right multiple times
    let result = left.let_in(id_gen, |id_gen, get_left| {
        right.let_in(id_gen, |id_gen, get_right| {
            // TODO: we know that we can re-use the arrangements of left and right
            // needed for the inner join with each of the conditional outer joins.
            // It is not clear whether we should hint that, or just let the planner
            // and optimizer run and see what happens.

            // We'll want the inner join (minus repeated columns)
            let join = MirRelationExpr::join(
                vec![get_left.clone(), get_right.clone()],
                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // remove those columns from `right` repeating the first `oa` columns.
            .project(
                (0..(oa + la))
                    .chain((oa + la + oa)..(oa + la + oa + ra))
                    .collect(),
            )
            // apply the filter constraints here, to ensure nulls are not matched.
            .filter(on);

            // We'll want to re-use the results of the join multiple times.
            join.let_in(id_gen, |id_gen, get_join| {
                // Start from the inner join; unmatched rows are unioned in below.
                let mut result = get_join.clone();

                // A collection of keys present in both left and right collections.
                let join_keys = on_predicates.join_keys();
                let both_keys_arity = join_keys.len();
                let both_keys = get_join.restrict(join_keys).distinct();

                // The plan is now to determine the left and right rows matched in the
                // inner join, subtract them from left and right respectively, pad what
                // remains with nulls, and fold them in to `result`.

                both_keys.let_in(id_gen, |_id_gen, get_both| {
                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
                        // Rows in `left` matched in the inner equijoin. This is
                        // a semi-join between `left` and `both_keys`.
                        let left_present = MirRelationExpr::join_scalars(
                            vec![
                                get_left
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.lhs()),
                                get_both.clone(),
                            ],
                            // Equate each left key expression with the matching
                            // `both_keys` column, which follows the `oa + la`
                            // columns of `left` in the joined schema.
                            itertools::zip_eq(
                                on_predicates.eq_lhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
                            )
                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
                            .collect(),
                        )
                        // Keep only the columns of `left`.
                        .project((0..(oa + la)).collect());

                        // Determine the types of nulls to use as filler.
                        let right_fill = rt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, filled with typed nulls.
                        // `left - left_present` are the rows of `left` without a match.
                        result = left_present
                            .negate()
                            .union(get_left.clone())
                            .map(right_fill)
                            .union(result);
                    }

                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                        // Rows in `right` matched in the inner equijoin. This
                        // is a semi-join between `right` and `both_keys`.
                        let right_present = MirRelationExpr::join_scalars(
                            vec![
                                get_right
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.rhs()),
                                get_both,
                            ],
                            // Equate each right key expression with the matching
                            // `both_keys` column, which follows the `oa + ra`
                            // columns of `right` in the joined schema.
                            itertools::zip_eq(
                                on_predicates.eq_rhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
                            )
                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
                            .collect(),
                        )
                        // Keep only the columns of `right`.
                        .project((0..(oa + ra)).collect());

                        // Determine the types of nulls to use as filler.
                        let left_fill = lt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, prepended with typed nulls.
                        // `right - right_present` are the rows of `right` without a match.
                        result = right_present
                            .negate()
                            .union(get_right.clone())
                            .map(left_fill)
                            // Permute left fill before right values: the mapped
                            // relation is [outer, right, left_fill] and must
                            // become [outer, left_fill, right].
                            .project(
                                itertools::chain!(
                                    0..oa,                 // Preserve `outer`.
                                    oa + ra..oa + la + ra, // The `la` null fillers move before `right`.
                                    oa..oa + ra            // The `ra` columns of `right` move last.
                                )
                                .collect(),
                            )
                            .union(result)
                    }

                    Ok::<_, PlanError>(result)
                })
            })
        })
    })?;
    Ok(Some(result))
}
2281
/// A struct that represents the predicates in the `on` clause in a form
/// suitable for efficient planning outer joins with equijoin predicates.
///
/// Instances are built with `OnPredicates::new`, which classifies every
/// conjunct of the `on` clause into an [`OnPredicate`] variant.
struct OnPredicates {
    /// A store for classified `ON` predicates.
    ///
    /// Predicates that reference a single side are adjusted to assume an
    /// `outer × <side>` schema.
    predicates: Vec<OnPredicate>,
    /// Number of outer context columns.
    ///
    /// Column references to these columns are emitted as a prefix of the
    /// equijoin keys.
    oa: usize,
}
2293
2294impl OnPredicates {
2295    const I_OUT: usize = 0; // outer context input position
2296    const I_LHS: usize = 1; // lhs input position
2297    const I_RHS: usize = 2; // rhs input position
2298    const I_SUB: usize = 3; // on subqueries input position
2299
2300    /// Classify the predicates in the `on` clause of an outer join.
2301    ///
2302    /// The other parameters are arities of the input parts:
2303    ///
2304    /// - `oa` is the arity of the `outer` context.
2305    /// - `la` is the arity of the `left` input.
2306    /// - `ra` is the arity of the `right` input.
2307    /// - `sa` is the arity of the `on` subqueries.
2308    ///
2309    /// The constructor assumes that:
2310    ///
2311    /// 1. The `on` parameter will be applied on a result that has the following
2312    ///    schema `outer × left × right × on_subqueries`.
2313    /// 2. The `on` parameter is already adjusted to assume that schema.
2314    /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2315    ///    MirScalarExpr` with `canonicalize_predicates`.
2316    fn new(
2317        oa: usize,
2318        la: usize,
2319        ra: usize,
2320        sa: usize,
2321        on: Vec<MirScalarExpr>,
2322        _context: &Context,
2323    ) -> Self {
2324        use mz_expr::BinaryFunc::Eq;
2325
2326        // Re-bind those locally for more compact pattern matching.
2327        const I_LHS: usize = OnPredicates::I_LHS;
2328        const I_RHS: usize = OnPredicates::I_RHS;
2329
2330        // Self parameters.
2331        let mut predicates = Vec::with_capacity(on.len());
2332
2333        // Helpers for populating `predicates`.
2334        let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2335        let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2336        let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2337            inner_join_mapper
2338                .lookup_inputs(expr)
2339                .filter(|&i| i != Self::I_OUT)
2340                .collect()
2341        };
2342        let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2343            inner_join_mapper
2344                .lookup_inputs(expr)
2345                .any(|i| i == Self::I_SUB)
2346        };
2347
2348        // Iterate over `on` elements and populate `predicates`.
2349        for mut predicate in on {
2350            if predicate.might_error() {
2351                tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2352                // Treat predicates that can produce a literal error as Theta.
2353                predicates.push(OnPredicate::Theta(predicate));
2354            } else if has_subquery_refs(&predicate) {
2355                tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2356                // Treat predicates referencing an `on` subquery as Theta.
2357                predicates.push(OnPredicate::Theta(predicate));
2358            } else if let MirScalarExpr::CallBinary {
2359                func: Eq(_),
2360                expr1,
2361                expr2,
2362            } = &mut predicate
2363            {
2364                // Obtain the non-outer inputs referenced by each side.
2365                let inputs1 = lookup_inputs(expr1);
2366                let inputs2 = lookup_inputs(expr2);
2367
2368                match (&inputs1[..], &inputs2[..]) {
2369                    // Neither side references an input. This could be a
2370                    // constant expression or an expression that depends only on
2371                    // the outer context.
2372                    ([], []) => {
2373                        predicates.push(OnPredicate::Const(predicate));
2374                    }
2375                    // Both sides reference different inputs.
2376                    ([I_LHS], [I_RHS]) => {
2377                        let lhs = expr1.take();
2378                        let mut rhs = expr2.take();
2379                        rhs.permute(&rhs_permutation);
2380                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2381                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2382                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2383                    }
2384                    // Both sides reference different inputs (swapped).
2385                    ([I_RHS], [I_LHS]) => {
2386                        let lhs = expr2.take();
2387                        let mut rhs = expr1.take();
2388                        rhs.permute(&rhs_permutation);
2389                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2390                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2391                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2392                    }
2393                    // Both sides reference the left input or no input.
2394                    ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2395                        predicates.push(OnPredicate::Lhs(predicate));
2396                    }
2397                    // Both sides reference the right input or no input.
2398                    ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2399                        predicate.permute(&rhs_permutation);
2400                        predicates.push(OnPredicate::Rhs(predicate));
2401                    }
2402                    // At least one side references more than one input.
2403                    _ => {
2404                        tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2405                        predicates.push(OnPredicate::Theta(predicate));
2406                    }
2407                }
2408            } else {
2409                // Obtain the non-outer inputs referenced by this predicate.
2410                let inputs = lookup_inputs(&predicate);
2411
2412                match &inputs[..] {
2413                    // The predicate references no inputs. This could be a
2414                    // constant expression or an expression that depends only on
2415                    // the outer context.
2416                    [] => {
2417                        predicates.push(OnPredicate::Const(predicate));
2418                    }
2419                    // The predicate references only the left input.
2420                    [I_LHS] => {
2421                        predicates.push(OnPredicate::Lhs(predicate));
2422                    }
2423                    // The predicate references only the right input.
2424                    [I_RHS] => {
2425                        predicate.permute(&rhs_permutation);
2426                        predicates.push(OnPredicate::Rhs(predicate));
2427                    }
2428                    // The predicate references both inputs.
2429                    _ => {
2430                        tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2431                        predicates.push(OnPredicate::Theta(predicate));
2432                    }
2433                }
2434            }
2435        }
2436
2437        Self { predicates, oa }
2438    }
2439
2440    /// Check if the predicates can be lowered with an equijoin-based strategy.
2441    fn is_equijoin(&self, context: &Context) -> bool {
2442        // Count each `OnPredicate` variant in `self.predicates`.
2443        let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2444            self.predicates.iter().fold(
2445                (0, 0, 0, 0, 0, 0),
2446                |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2447                    (
2448                        const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2449                        lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2450                        rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2451                        eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2452                        eq_cols + usize::from(matches!(p, OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column())),
2453                        theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2454                    )
2455                },
2456            );
2457
2458        let is_equijion = if context.config.enable_new_outer_join_lowering {
2459            // New classifier.
2460            eq_cnt > 0 && theta_cnt == 0
2461        } else {
2462            // Old classifier.
2463            eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2464        };
2465
2466        // Log an entry only if this is an equijoin according to the new classifier.
2467        if eq_cnt > 0 && theta_cnt == 0 {
2468            tracing::debug!(
2469                const_cnt,
2470                lhs_cnt,
2471                rhs_cnt,
2472                eq_cnt,
2473                eq_cols,
2474                theta_cnt,
2475                "OnPredicates::is_equijoin"
2476            );
2477        }
2478
2479        is_equijion
2480    }
2481
2482    /// Return an [`MirRelationExpr`] list that represents the keys for the
2483    /// equijoin. The list will contain the outer columns as a prefix.
2484    fn join_keys(&self) -> JoinKeys {
2485        // We could return either the `lhs` or the `rhs` of the keys used to
2486        // form the inner join as they are equated by the join condition.
2487        let join_keys = self.eq_lhs().collect::<Vec<_>>();
2488
2489        if join_keys.iter().all(|k| k.is_column()) {
2490            tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2491            JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2492        } else {
2493            tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2494            JoinKeys::Scalars(join_keys)
2495        }
2496    }
2497
2498    /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2499    /// conditions in the predicates list.
2500    ///
2501    /// The iterator will start with column references to the outer columns as a
2502    /// prefix.
2503    fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2504        itertools::chain(
2505            (0..self.oa).map(MirScalarExpr::column),
2506            self.predicates.iter().filter_map(|e| match e {
2507                OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2508                _ => None,
2509            }),
2510        )
2511    }
2512
2513    /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2514    /// conditions in the predicates list.
2515    ///
2516    /// The iterator will start with column references to the outer columns as a
2517    /// prefix.
2518    fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2519        itertools::chain(
2520            (0..self.oa).map(MirScalarExpr::column),
2521            self.predicates.iter().filter_map(|e| match e {
2522                OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2523                _ => None,
2524            }),
2525        )
2526    }
2527
2528    /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2529    /// [`OnPredicate::Const`] conditions in the predicates list.
2530    fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2531        self.predicates.iter().filter_map(|p| match p {
2532            // We treat Const predicates local to both inputs.
2533            OnPredicate::Const(p) => Some(p.clone()),
2534            OnPredicate::Lhs(p) => Some(p.clone()),
2535            OnPredicate::LhsConsequence(p) => Some(p.clone()),
2536            _ => None,
2537        })
2538    }
2539
2540    /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2541    /// [`OnPredicate::Const`] conditions in the predicates list.
2542    fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2543        self.predicates.iter().filter_map(|p| match p {
2544            // We treat Const predicates local to both inputs.
2545            OnPredicate::Const(p) => Some(p.clone()),
2546            OnPredicate::Rhs(p) => Some(p.clone()),
2547            OnPredicate::RhsConsequence(p) => Some(p.clone()),
2548            _ => None,
2549        })
2550    }
2551}
2552
/// A single `ON` predicate, classified by the join inputs it references.
enum OnPredicate {
    // A predicate that is either constant or references only outer columns.
    Const(MirScalarExpr),
    // A local predicate on the left-hand side of the join, i.e., it references only the left input
    // and possibly outer columns.
    //
    // This is one of the original predicates from the ON clause.
    //
    // One _must_ apply this predicate.
    Lhs(MirScalarExpr),
    // A local predicate on the left-hand side of the join, i.e., it references only the left input
    // and possibly outer columns.
    //
    // This is not one of the original predicates from the ON clause, but is just a consequence
    // of an original predicate in the ON clause, where the original predicate references both
    // inputs, but the consequence references only the left input.
    //
    // For example, the original predicate `input1.x = input2.a` has the consequence
    // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
    // prevents null skew, and also makes more CSE opportunities available when the left input's key
    // doesn't have a NOT NULL constraint, saving us an arrangement.
    //
    // Applying the predicate is optional, because the original predicate will be applied anyway.
    LhsConsequence(MirScalarExpr),
    // A local predicate on the right-hand side of the join, i.e., it references only the right
    // input and possibly outer columns.
    //
    // This is one of the original predicates from the ON clause.
    //
    // One _must_ apply this predicate.
    Rhs(MirScalarExpr),
    // A consequence of an original ON predicate that references only the right input; the
    // right-hand counterpart of `LhsConsequence`.
    //
    // For example, the original predicate `input1.x = input2.a` has the consequence
    // `input2.a IS NOT NULL`.
    //
    // Applying the predicate is optional, because the original predicate will be applied anyway.
    RhsConsequence(MirScalarExpr),
    // An equality predicate between the two sides. The first expression is the side over the
    // left input, the second expression is the side over the right input.
    Eq(MirScalarExpr, MirScalarExpr),
    // A non-equality predicate between the two sides, or any other predicate that prevents the
    // equijoin-based lowering (e.g. one that might error or references an `ON` subquery).
    //
    // The wrapped expression is currently only stored, never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    Theta(MirScalarExpr),
}
2591
/// A set of join keys referencing an input.
///
/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
/// avoid changes (and thereby possible regressions) in plans that have equijoin
/// predicates consisting only of column refs.
///
/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
/// able to get rid of this code, but as it stands `Map` simplification seems
/// more cumbersome than `Project` simplification, so do this just to be sure.
enum JoinKeys {
    /// Join keys that are plain column references, given as column indices of
    /// the input. Restricting an input to these keys is a simple `project`.
    Outputs(Vec<usize>),
    /// Join keys given as scalar expressions over the input. Restricting an
    /// input to these keys first appends them with `map` and then projects
    /// onto the appended columns.
    Scalars(Vec<MirScalarExpr>),
}
2607
2608impl JoinKeys {
2609    fn len(&self) -> usize {
2610        match self {
2611            JoinKeys::Outputs(outputs) => outputs.len(),
2612            JoinKeys::Scalars(scalars) => scalars.len(),
2613        }
2614    }
2615}
2616
/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
/// code.
trait LoweringExt {
    /// Consumes `self` and returns a relation whose columns are restricted to
    /// the given `join_keys`.
    ///
    /// See [`MirRelationExpr::restrict`].
    fn restrict(self, join_keys: JoinKeys) -> Self;
}
2623
2624impl LoweringExt for MirRelationExpr {
2625    /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2626    fn restrict(self, join_keys: JoinKeys) -> Self {
2627        let num_keys = join_keys.len();
2628        match join_keys {
2629            JoinKeys::Outputs(outputs) => self.project(outputs),
2630            JoinKeys::Scalars(scalars) => {
2631                let input_arity = self.arity();
2632                let outputs = (input_arity..input_arity + num_keys).collect();
2633                self.map(scalars).project(outputs)
2634            }
2635        }
2636    }
2637}