mz_sql/plan/lowering.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::visit::Visit;
44use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
45use mz_ore::collections::CollectionExt;
46use mz_ore::stack::maybe_grow;
47use mz_repr::*;
48
49use crate::optimizer_metrics::OptimizerMetrics;
50use crate::plan::hir::{
51 AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
52};
53use crate::plan::{PlanError, transform_hir};
54use crate::session::vars::SystemVars;
55
56mod variadic_left;
57
58/// Maps a leveled column reference to a specific column.
59///
60/// Leveled column references are nested, so that larger levels are
61/// found early in a record and level zero is found at the end.
62///
63/// The column map only stores references for levels greater than zero,
64/// and column references at level zero simply start at the first column
65/// after all prior references.
66#[derive(Debug, Clone)]
67struct ColumnMap {
68 inner: BTreeMap<ColumnRef, usize>,
69}
70
71impl ColumnMap {
72 fn empty() -> ColumnMap {
73 Self::new(BTreeMap::new())
74 }
75
76 fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
77 ColumnMap { inner }
78 }
79
80 fn get(&self, col_ref: &ColumnRef) -> usize {
81 if col_ref.level == 0 {
82 self.inner.len() + col_ref.column
83 } else {
84 self.inner[col_ref]
85 }
86 }
87
88 fn len(&self) -> usize {
89 self.inner.len()
90 }
91
92 /// Updates references in the `ColumnMap` for use in a nested scope. The
93 /// provided `arity` must specify the arity of the current scope.
94 fn enter_scope(&self, arity: usize) -> ColumnMap {
95 // From the perspective of the nested scope, all existing column
96 // references will be one level greater.
97 let existing = self
98 .inner
99 .clone()
100 .into_iter()
101 .update(|(col, _i)| col.level += 1);
102
103 // All columns in the current scope become explicit entries in the
104 // immediate parent scope.
105 let new = (0..arity).map(|i| {
106 (
107 ColumnRef {
108 level: 1,
109 column: i,
110 },
111 self.len() + i,
112 )
113 });
114
115 ColumnMap::new(existing.chain(new).collect())
116 }
117}
118
/// Map with the CTEs currently in scope, keyed by the `LocalId` under which
/// the CTE is looked up during lowering.
type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
121
/// Information needed when resolving a reference to a CTE in scope.
#[derive(Clone)]
struct CteDesc {
    /// The new ID assigned to the lowered version of the CTE, which may not match
    /// the ID of the input CTE.
    new_id: mz_expr::LocalId,
    /// The relation type of the CTE including the columns from the outer
    /// context at the beginning.
    relation_type: SqlRelationType,
    /// The outer relation the CTE was applied to.
    outer_relation: MirRelationExpr,
}
134
/// Feature flags that affect how lowering behaves.
///
/// Can be built from `SystemVars` through the `From<&SystemVars>`
/// implementation, which reads the accessor of the same name for each field.
#[derive(Debug)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
    /// Mirrors the `enable_guard_subquery_tablefunc` system variable.
    // NOTE(review): the code that consumes this flag is not visible in this
    // part of the file; confirm its exact effect before relying on it.
    pub enable_guard_subquery_tablefunc: bool,
}
143
144impl From<&SystemVars> for Config {
145 fn from(vars: &SystemVars) -> Self {
146 Self {
147 enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
148 enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
149 enable_guard_subquery_tablefunc: vars.enable_guard_subquery_tablefunc(),
150 }
151 }
152}
153
/// Context passed to the lowering. This is wired to most parts of the lowering.
///
/// It only borrows its contents, so it is cheap to pass by reference through
/// the recursive lowering calls.
pub(crate) struct Context<'a> {
    /// Feature flags affecting the behavior of lowering.
    pub config: &'a Config,
    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's None, we
    /// simply don't write metrics.
    pub metrics: Option<&'a OptimizerMetrics>,
}
162
impl HirRelationExpr {
    /// Rewrite `self` into a `MirRelationExpr`.
    /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries.
    ///
    /// `config` is converted into a [`Config`] and, together with the optional
    /// `metrics`, bundled into the [`Context`] that is threaded through the
    /// lowering. The resulting plan is handed to `mz_repr::explain::trace_plan`
    /// before it is returned.
    ///
    /// # Errors
    ///
    /// Returns a `PlanError` if one of the HIR pre-processing passes or the
    /// decorrelation itself fails.
    #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
    pub fn lower<C: Into<Config>>(
        self,
        config: C,
        metrics: Option<&OptimizerMetrics>,
    ) -> Result<MirRelationExpr, PlanError> {
        let context = Context {
            config: &config.into(),
            metrics,
        };
        let result = match self {
            // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
            // to ensure that the downstream optimizer can easily bypass most
            // irrelevant optimizations (e.g. reduce folding) for this expression
            // without having to re-learn the fact that it is just a constant,
            // as it would if the constant were wrapped in a Let-Get pair.
            HirRelationExpr::Constant { rows, typ } => {
                let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
                MirRelationExpr::Constant {
                    rows: Ok(rows),
                    typ,
                }
            }
            mut other => {
                let mut id_gen = mz_ore::id_gen::IdGen::default();
                // HIR-level rewrites that run before decorrelation.
                transform_hir::split_subquery_predicates(&mut other)?;
                transform_hir::try_simplify_quantified_comparisons(&mut other)?;
                transform_hir::fuse_window_functions(&mut other, &context)?;
                // The initial outer relation has a single row and no columns, so the
                // top-level query is evaluated exactly once, with an empty column map
                // and no CTEs in scope.
                MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![])).let_in(
                    &mut id_gen,
                    |id_gen, get_outer| {
                        other.applied_to(
                            id_gen,
                            get_outer,
                            &ColumnMap::empty(),
                            &mut CteMap::new(),
                            &context,
                        )
                    },
                )?
            }
        };

        mz_repr::explain::trace_plan(&result);

        Ok(result)
    }

    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
    ///
    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
    /// When `self` references columns of `get_outer`, much more work needs to occur.
    ///
    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
    /// perhaps not all of them. It should be used as the basis of resolving column references,
    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
    /// will start, rather than any function of `col_map`.
    ///
    /// NOTE(review): the body asserts `col_map.len() == get_outer.arity()`, which is stricter
    /// than "perhaps not all of them" above; confirm which of the two is the intended contract.
    ///
    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
    /// assignment of values to outer rows.
    ///
    /// `cte_map` holds the CTEs currently in scope. The `Let` and `LetRec` arms add their
    /// bindings while lowering their bodies and restore any shadowed entries afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any `PlanError` produced while lowering a nested relation or scalar
    /// expression.
    ///
    /// # Panics
    ///
    /// Panics if `get_outer` is not a `MirRelationExpr::Get`, if `col_map.len()` differs from
    /// `get_outer.arity()`, if a `Get` of a local id has no entry in `cte_map`, if a `Map`
    /// scalar references itself or a later column, or if the `OFFSET` of a `TopK` is not a
    /// non-negative integer literal.
    fn applied_to(
        self,
        id_gen: &mut mz_ore::id_gen::IdGen,
        get_outer: MirRelationExpr,
        col_map: &ColumnMap,
        cte_map: &mut CteMap,
        context: &Context,
    ) -> Result<MirRelationExpr, PlanError> {
        maybe_grow(|| {
            use MirRelationExpr as SR;

            use HirRelationExpr::*;

            // Invariant checks on the arguments; see the `# Panics` section above.
            if let MirRelationExpr::Get { .. } = &get_outer {
            } else {
                panic!(
                    "get_outer: expected a MirRelationExpr::Get, found\n{}",
                    get_outer.pretty(),
                );
            }
            assert_eq!(col_map.len(), get_outer.arity());
            Ok(match self {
                Constant { rows, typ } => {
                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
                    get_outer.product(SR::Constant {
                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
                        typ,
                    })
                }
                Get { id, typ } => match id {
                    mz_expr::Id::Local(local_id) => {
                        let cte_desc = cte_map.get(&local_id).unwrap();
                        let get_cte = SR::Get {
                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
                            typ: cte_desc.relation_type.clone(),
                            access_strategy: AccessStrategy::UnknownOrLocal,
                        };
                        if get_outer == cte_desc.outer_relation {
                            // If the CTE was applied to the same exact relation, we can safely
                            // return a `Get` relation.
                            get_cte
                        } else {
                            // Otherwise, the new outer relation may contain more columns from some
                            // intermediate scope placed between the definition of the CTE and this
                            // reference of the CTE and/or more operations applied on top of the
                            // outer relation.
                            //
                            // An example of the latter is the following query:
                            //
                            // SELECT *
                            // FROM x,
                            // LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
                            // SELECT (SELECT m FROM a) FROM y) b;
                            //
                            // When the CTE is lowered, the outer relation is `Get x`. But then,
                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
                            // which has the same cardinality as `Get x`.
                            //
                            // In any case, `get_outer` is guaranteed to contain the columns of the
                            // outer relation the CTE was applied to at its prefix. Since we must
                            // return a relation containing `get_outer`'s columns at the beginning,
                            // we must build a join between `get_outer` and `get_cte` on their common
                            // columns.
                            let oa = get_outer.arity();
                            // Number of leading columns of the CTE that come from its outer relation.
                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
                            let equivalences = (0..cte_outer_columns)
                                .map(|pos| {
                                    vec![
                                        MirScalarExpr::column(pos),
                                        MirScalarExpr::column(pos + oa),
                                    ]
                                })
                                .collect();

                            // Project out the second copy of the columns common to `get_outer` and
                            // `cte_desc.outer_relation`.
                            let projection = (0..oa)
                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
                                .collect_vec();
                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
                                .project(projection)
                        }
                    }
                    _ => {
                        // Get statements are only to external sources, and are not correlated with `get_outer`.
                        get_outer.product(SR::Get {
                            id,
                            typ,
                            access_strategy: AccessStrategy::UnknownOrLocal,
                        })
                    }
                },
                Let {
                    name: _,
                    id,
                    value,
                    body,
                } => {
                    let value =
                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    value.let_in(id_gen, |id_gen, get_value| {
                        let (new_id, typ) = if let MirRelationExpr::Get {
                            id: mz_expr::Id::Local(id),
                            typ,
                            ..
                        } = get_value
                        {
                            (id, typ)
                        } else {
                            panic!(
                                "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
                                get_value.pretty(),
                            );
                        };
                        // Add the information about the CTE to the map and remove it when
                        // it goes out of scope.
                        let old_value = cte_map.insert(
                            id.clone(),
                            CteDesc {
                                new_id,
                                relation_type: typ,
                                outer_relation: get_outer.clone(),
                            },
                        );
                        // `body` is deliberately kept as a `Result` here: the map is restored
                        // below before any error is propagated by the `?` after `let_in`.
                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
                        if let Some(old_value) = old_value {
                            cte_map.insert(id, old_value);
                        } else {
                            cte_map.remove(&id);
                        }
                        body
                    })?
                }
                LetRec {
                    limit,
                    bindings,
                    body,
                } => {
                    let num_bindings = bindings.len();

                    // We use the outer type with the HIR types to form MIR CTE types.
                    let outer_column_types = get_outer.typ().column_types;

                    // Rename and introduce all bindings. All of them are registered before any
                    // value is lowered, so values may refer to any binding of this `LetRec`.
                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
                    let mut mir_ids = Vec::with_capacity(num_bindings);
                    for (_name, id, _value, typ) in bindings.iter() {
                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
                        mir_ids.push(mir_id);
                        let shadowed = cte_map.insert(
                            id.clone(),
                            CteDesc {
                                new_id: mir_id,
                                relation_type: SqlRelationType::new(
                                    outer_column_types
                                        .iter()
                                        .cloned()
                                        .chain(typ.column_types.iter().cloned())
                                        .collect::<Vec<_>>(),
                                ),
                                outer_relation: get_outer.clone(),
                            },
                        );
                        shadowed_bindings.push((*id, shadowed));
                    }

                    let mut mir_values = Vec::with_capacity(num_bindings);
                    for (_name, _id, value, _typ) in bindings.into_iter() {
                        mir_values.push(value.applied_to(
                            id_gen,
                            get_outer.clone(),
                            col_map,
                            cte_map,
                            context,
                        )?);
                    }

                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;

                    // Remove our bindings and reinstate any shadowed bindings.
                    for (id, shadowed) in shadowed_bindings {
                        if let Some(shadowed) = shadowed {
                            cte_map.insert(id, shadowed);
                        } else {
                            cte_map.remove(&id);
                        }
                    }

                    MirRelationExpr::LetRec {
                        ids: mir_ids,
                        values: mir_values,
                        // Copy the limit to each binding.
                        limits: repeat(limit).take(num_bindings).collect(),
                        body: Box::new(mir_body),
                    }
                }
                Project { input, outputs } => {
                    // Projections should be applied to the decorrelated `inner`, and to its columns,
                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
                    let input =
                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    let outputs = (0..get_outer.arity())
                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
                        .collect::<Vec<_>>();
                    input.project(outputs)
                }
                Map { input, mut scalars } => {
                    // Scalar expressions may contain correlated subqueries. We must be cautious!

                    // We lower scalars in chunks, and must keep track of the
                    // arity of the HIR fragments lowered so far.
                    let mut lowered_arity = input.arity();

                    let mut input =
                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;

                    // Lower subqueries in maximally sized batches, such that no subquery in the current
                    // batch depends on columns from the same batch.
                    // Note that subqueries in this projection may reference columns added by this
                    // Map operator, so we need to ensure these columns exist before lowering the
                    // subquery.
                    while !scalars.is_empty() {
                        // Index of the first scalar that needs a column which has not been
                        // lowered yet; everything before it forms the current batch.
                        let end_idx = scalars
                            .iter_mut()
                            .position(|s| {
                                let mut requires_nonexistent_column = false;
                                #[allow(deprecated)]
                                s.visit_columns(0, &mut |depth, col| {
                                    if col.level == depth {
                                        requires_nonexistent_column |= col.column >= lowered_arity
                                    }
                                });
                                requires_nonexistent_column
                            })
                            .unwrap_or(scalars.len());
                        // An empty batch would make no progress and loop forever.
                        assert!(
                            end_idx > 0,
                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
                            lowered_arity,
                            scalars
                        );

                        lowered_arity = lowered_arity + end_idx;
                        let scalars = scalars.drain(0..end_idx).collect_vec();

                        let old_arity = input.arity();
                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
                            &scalars, id_gen, col_map, cte_map, input, context,
                        )?;
                        input = with_subqueries;

                        // We will proceed sequentially through the scalar expressions, for each transforming
                        // the decorrelated `input` into a relation with potentially more columns capable of
                        // addressing the needs of the scalar expression.
                        // Having done so, we add the scalar value of interest and trim off any other newly
                        // added columns.
                        //
                        // The sequential traversal is present as expressions are allowed to depend on the
                        // values of prior expressions.
                        let mut scalar_columns = Vec::new();
                        for scalar in scalars {
                            let scalar = scalar.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                &mut input,
                                &Some(&subquery_map),
                                context,
                            )?;
                            input = input.map_one(scalar);
                            scalar_columns.push(input.arity() - 1);
                        }

                        // Discard any new columns added by the lowering of the scalar expressions
                        input = input.project((0..old_arity).chain(scalar_columns).collect());
                    }

                    input
                }
                CallTable { func, exprs } => {
                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
                    // allowed to refer to the results of previous expressions, and we have a simpler
                    // implementation that appends all relevant columns first, then applies the flatmap
                    // operator to the result, then strips off any columns introduced by subqueries.

                    let mut input = get_outer;
                    let old_arity = input.arity();

                    let exprs = exprs
                        .into_iter()
                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
                        .collect::<Result<Vec<_>, _>>()?;

                    let new_arity = input.arity();
                    let output_arity = func.output_arity();
                    input = input.flat_map(func, exprs);
                    if old_arity != new_arity {
                        // this means we added some columns to handle subqueries, and now we need to get rid of them
                        input = input.project(
                            (0..old_arity)
                                .chain(new_arity..new_arity + output_arity)
                                .collect(),
                        );
                    }
                    input
                }
                Filter { input, predicates } => {
                    // Filter expressions may contain correlated subqueries.
                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
                    // then filter the results, then strip off any columns that were added for this purpose.
                    let mut input =
                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
                    for predicate in predicates {
                        let old_arity = input.arity();
                        let predicate = predicate
                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
                        let new_arity = input.arity();
                        input = input.filter(vec![predicate]);
                        if old_arity != new_arity {
                            // this means we added some columns to handle subqueries, and now we need to get rid of them
                            input = input.project((0..old_arity).collect());
                        }
                    }
                    input
                }
                Join {
                    left,
                    right,
                    on,
                    kind,
                } if right.is_correlated() => {
                    // A correlated join is a join in which the right expression has
                    // access to the columns in the left expression. It turns out
                    // this is *exactly* our branch operator, plus some additional
                    // null handling in the case of left joins. (Right and full
                    // lateral joins are not permitted.)
                    //
                    // As with normal joins, the `on` predicate may be correlated,
                    // and we treat it as a filter that follows the branch.

                    assert!(kind.can_be_correlated());

                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
                    left.let_in(id_gen, |id_gen, get_left| {
                        let apply_requires_distinct_outer = false;
                        let mut join = branch(
                            id_gen,
                            get_left.clone(),
                            col_map,
                            cte_map,
                            *right,
                            apply_requires_distinct_outer,
                            context,
                            |id_gen, right, get_left, col_map, cte_map, context| {
                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
                            },
                        )?;

                        // Plan the `on` predicate.
                        let old_arity = join.arity();
                        let on =
                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
                        join = join.filter(vec![on]);
                        let new_arity = join.arity();
                        if old_arity != new_arity {
                            // This means we added some columns to handle
                            // subqueries, and now we need to get rid of them.
                            join = join.project((0..old_arity).collect());
                        }

                        // If a left join, reintroduce any rows from the left that
                        // are missing, with nulls filled in for the right columns.
                        if let JoinKind::LeftOuter { .. } = kind {
                            let default = join
                                .typ()
                                .column_types
                                .into_iter()
                                .skip(get_left.arity())
                                .map(|typ| (Datum::Null, typ.scalar_type))
                                .collect();
                            get_left.lookup(id_gen, join, default)
                        } else {
                            Ok::<_, PlanError>(join)
                        }
                    })?
                }
                Join {
                    left,
                    right,
                    on,
                    kind,
                } => {
                    if context.config.enable_variadic_left_join_lowering {
                        // Attempt to extract a stack of left joins.
                        if let JoinKind::LeftOuter = kind {
                            // Walk down the left spine, collecting the `(right, on)` pair of
                            // every directly nested left join; `left_test` ends up pointing at
                            // the innermost left input.
                            let mut rights = vec![(&*right, &on)];
                            let mut left_test = &left;
                            while let Join {
                                left,
                                right,
                                on,
                                kind: JoinKind::LeftOuter,
                            } = &**left_test
                            {
                                rights.push((&**right, on));
                                left_test = left;
                            }
                            if rights.len() > 1 {
                                // Defensively clone `cte_map` as it may be mutated.
                                let cte_map_clone = cte_map.clone();
                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
                                    left_test,
                                    rights,
                                    id_gen,
                                    get_outer.clone(),
                                    col_map,
                                    cte_map,
                                    context,
                                ) {
                                    return Ok(magic);
                                } else {
                                    // Both `Ok(None)` and `Err(_)` land here: restore the map and
                                    // fall through to the general lowering below.
                                    cte_map.clone_from(&cte_map_clone);
                                }
                            }
                        }
                    }

                    // Both join expressions should be decorrelated, and then joined by their
                    // leading columns to form only those pairs corresponding to the same row
                    // of `get_outer`.
                    //
                    // The `on` predicate may contain correlated subqueries, and we treat it
                    // as though it was a filter, with the caveat that we also translate outer
                    // joins in this step. The post-filtration results need to be considered
                    // against the records present in the left and right (decorrelated) inputs,
                    // depending on the type of join.
                    //
                    // Naming below: `oa`, `la` and `ra` are the number of columns contributed
                    // by the outer relation, the left input and the right input respectively;
                    // `lt` and `rt` are the column types of the left and right inputs.
                    let oa = get_outer.arity();
                    let left =
                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
                    let la = lt.len();
                    left.let_in(id_gen, |id_gen, get_left| {
                        let right_col_map = col_map.enter_scope(0);
                        let right = right.applied_to(
                            id_gen,
                            get_outer.clone(),
                            &right_col_map,
                            cte_map,
                            context,
                        )?;
                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
                        let ra = rt.len();
                        right.let_in(id_gen, |id_gen, get_right| {
                            let mut product = SR::join(
                                vec![get_left.clone(), get_right.clone()],
                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
                            )
                            // Project away the repeated copy of get_outer's columns.
                            .project(
                                (0..(oa + la))
                                    .chain((oa + la + oa)..(oa + la + oa + ra))
                                    .collect(),
                            );

                            // Decorrelate and lower the `on` clause.
                            let on = on.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                &mut product,
                                &None,
                                context,
                            )?;
                            // Collect the types of all subqueries appearing in
                            // the `on` clause. The subquery results were
                            // appended to `product` in the `on.applied_to(...)`
                            // call above.
                            let on_subquery_types = product
                                .typ()
                                .column_types
                                .drain(oa + la + ra..)
                                .collect_vec();
                            // Remember if `on` had any subqueries.
                            let on_has_subqueries = !on_subquery_types.is_empty();

                            // Attempt an efficient equijoin implementation, in which outer joins are
                            // more efficiently rendered than in general. This can return `None` if
                            // such a plan is not possible, for example if `on` does not describe an
                            // equijoin between columns of `left` and `right`.
                            if kind != JoinKind::Inner {
                                if let Some(joined) = attempt_outer_equijoin(
                                    get_left.clone(),
                                    get_right.clone(),
                                    on.clone(),
                                    on_subquery_types,
                                    kind.clone(),
                                    oa,
                                    id_gen,
                                    context,
                                )? {
                                    if let Some(metrics) = context.metrics {
                                        metrics.inc_outer_join_lowering("equi");
                                    }
                                    return Ok(joined);
                                }
                            }

                            // Otherwise, perform a more general join.
                            if let Some(metrics) = context.metrics {
                                metrics.inc_outer_join_lowering("general");
                            }
                            let mut join = product.filter(vec![on]);
                            if on_has_subqueries {
                                // This means that `on.applied_to(...)` appended
                                // some columns to handle subqueries, and now we
                                // need to get rid of them.
                                join = join.project((0..oa + la + ra).collect());
                            }
                            join.let_in(id_gen, |id_gen, get_join| {
                                let mut result = get_join.clone();
                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
                                    kind
                                {
                                    // The right columns are filled with nulls of the right types.
                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
                                        id_gen,
                                        get_join.clone(),
                                        rt.into_iter()
                                            .map(|typ| (Datum::Null, typ.scalar_type))
                                            .collect(),
                                    )?;
                                    result = result.union(left_outer);
                                }
                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                                    let right_outer = get_right
                                        .clone()
                                        .anti_lookup::<PlanError>(
                                            id_gen,
                                            get_join
                                                // need to swap left and right to make the anti_lookup work
                                                .project(
                                                    (0..oa)
                                                        .chain((oa + la)..(oa + la + ra))
                                                        .chain((oa)..(oa + la))
                                                        .collect(),
                                                ),
                                            lt.into_iter()
                                                .map(|typ| (Datum::Null, typ.scalar_type))
                                                .collect(),
                                        )?
                                        // swap left and right back again
                                        .project(
                                            (0..oa)
                                                .chain((oa + ra)..(oa + ra + la))
                                                .chain((oa)..(oa + ra))
                                                .collect(),
                                        );
                                    result = result.union(right_outer);
                                }
                                Ok::<MirRelationExpr, PlanError>(result)
                            })
                        })
                    })?
                }
                Union { base, inputs } => {
                    // Union is uncomplicated.
                    SR::Union {
                        base: Box::new(base.applied_to(
                            id_gen,
                            get_outer.clone(),
                            col_map,
                            cte_map,
                            context,
                        )?),
                        inputs: inputs
                            .into_iter()
                            .map(|input| {
                                input.applied_to(
                                    id_gen,
                                    get_outer.clone(),
                                    col_map,
                                    cte_map,
                                    context,
                                )
                            })
                            .collect::<Result<Vec<_>, _>>()?,
                    }
                }
                Reduce {
                    input,
                    group_key,
                    aggregates,
                    expected_group_size,
                } => {
                    // Reduce may contain expressions with correlated subqueries.
                    // In addition, here an empty reduction key signifies that we need to supply default values
                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
                    let mut input =
                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    // Group by all outer columns in addition to the requested key, rebased
                    // past the outer columns.
                    let applied_group_key = (0..get_outer.arity())
                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
                        .collect();
                    let applied_aggregates = aggregates
                        .into_iter()
                        .map(|aggregate| {
                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    let input_type = input.typ();
                    let default = applied_aggregates
                        .iter()
                        .map(|agg| {
                            (
                                agg.func.default(),
                                agg.typ(&input_type.column_types).scalar_type,
                            )
                        })
                        .collect();
                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
                    let mut reduced =
                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);

                    // Introduce default values in the case the group key is empty.
                    if group_key.is_empty() {
                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
                    }
                    reduced
                }
                Distinct { input } => {
                    // Distinct is uncomplicated.
                    input
                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
                        .distinct()
                }
                TopK {
                    input,
                    group_key,
                    order_key,
                    limit,
                    offset,
                    expected_group_size,
                } => {
                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
                    let mut input =
                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
                        .collect();
                    let applied_order_key = order_key
                        .iter()
                        .map(|column_order| ColumnOrder {
                            column: column_order.column + get_outer.arity(),
                            desc: column_order.desc,
                            nulls_last: column_order.nulls_last,
                        })
                        .collect();

                    let old_arity = input.arity();

                    // Lower `limit`, which may introduce new columns if it is a correlated subquery.
                    let mut limit_mir = None;
                    if let Some(limit) = limit {
                        limit_mir = Some(
                            limit
                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
                        );
                    }

                    let new_arity = input.arity();
                    // Extend the key to contain any new columns.
                    applied_group_key.extend(old_arity..new_arity);

                    let offset = offset
                        .try_into_literal_int64()
                        .expect("Should be a Literal by this time")
                        .try_into()
                        .expect("Should have checked non-negativity of OFFSET clause already");
                    let mut result = input.top_k(
                        applied_group_key,
                        applied_order_key,
                        limit_mir,
                        offset,
                        expected_group_size,
                    );

                    // If new columns were added for `limit` we must remove them.
                    if old_arity != new_arity {
                        result = result.project((0..old_arity).collect());
                    }

                    result
                }
                Negate { input } => {
                    // Negate is uncomplicated.
                    input
                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
                        .negate()
                }
                Threshold { input } => {
                    // Threshold is uncomplicated.
                    input
                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
                        .threshold()
                }
            })
        })
    }
}
933
934impl HirScalarExpr {
935 /// Rewrite `self` into a `mz_expr::ScalarExpr` which can be applied to the modified `inner`.
936 ///
937 /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
938 /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
939 /// The most complicated logic is for the scalar expressions that involve subqueries, each of which are
940 /// documented in more detail closer to their logic.
941 ///
942 /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
943 /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
944 /// each of these references can be found.
945 fn applied_to(
946 self,
947 id_gen: &mut mz_ore::id_gen::IdGen,
948 col_map: &ColumnMap,
949 cte_map: &mut CteMap,
950 inner: &mut MirRelationExpr,
951 subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
952 context: &Context,
953 ) -> Result<MirScalarExpr, PlanError> {
954 maybe_grow(|| {
955 use MirScalarExpr as SS;
956
957 use HirScalarExpr::*;
958
959 if let Some(subquery_map) = subquery_map {
960 if let Some(col) = subquery_map.get(&self) {
961 return Ok(SS::column(*col));
962 }
963 }
964
965 Ok::<MirScalarExpr, PlanError>(match self {
966 Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
967 Literal(row, typ, _name) => SS::Literal(Ok(row), typ),
968 Parameter(_, _name) => {
969 panic!("cannot decorrelate expression with unbound parameters")
970 }
971 CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
972 CallUnary {
973 func,
974 expr,
975 name: _,
976 } => SS::CallUnary {
977 func,
978 expr: Box::new(expr.applied_to(
979 id_gen,
980 col_map,
981 cte_map,
982 inner,
983 subquery_map,
984 context,
985 )?),
986 },
987 CallBinary {
988 func,
989 expr1,
990 expr2,
991 name: _,
992 } => SS::CallBinary {
993 func,
994 expr1: Box::new(expr1.applied_to(
995 id_gen,
996 col_map,
997 cte_map,
998 inner,
999 subquery_map,
1000 context,
1001 )?),
1002 expr2: Box::new(expr2.applied_to(
1003 id_gen,
1004 col_map,
1005 cte_map,
1006 inner,
1007 subquery_map,
1008 context,
1009 )?),
1010 },
1011 CallVariadic {
1012 func,
1013 exprs,
1014 name: _,
1015 } => SS::CallVariadic {
1016 func,
1017 exprs: exprs
1018 .into_iter()
1019 .map(|expr| {
1020 expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
1021 })
1022 .collect::<Result<Vec<_>, _>>()?,
1023 },
1024 If {
1025 cond,
1026 then,
1027 els,
1028 name,
1029 } => {
1030 // The `If` case is complicated by the fact that we do not want to
1031 // apply the `then` or `else` logic to tuples that respectively do
1032 // not or do pass the `cond` test. Our strategy is to independently
1033 // decorrelate the `then` and `else` logic, and apply each to tuples
1034 // that respectively pass and do not pass the `cond` logic (which is
1035 // executed, and so decorrelated, for all tuples).
1036 //
1037 // Informally, we turn the `if` statement into:
1038 //
1039 // let then_case = inner.filter(cond).map(then);
1040 // let else_case = inner.filter(!cond).map(else);
1041 // return then_case.concat(else_case);
1042 //
1043 // We only require this if either expression would result in any
1044 // computation beyond the expr itself, which we will interpret as
1045 // "introduces additional columns". In the absence of correlation,
1046 // we should just retain a `ScalarExpr::If` expression; the inverse
1047 // transformation as above is complicated to recover after the fact,
1048 // and we would benefit from not introducing the complexity.
1049
1050 let inner_arity = inner.arity();
1051 let cond_expr =
1052 cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1053
1054 // Defensive copies, in case we mangle these in decorrelation.
1055 let inner_clone = inner.clone();
1056 let then_clone = then.clone();
1057 let else_clone = els.clone();
1058
1059 let cond_arity = inner.arity();
1060 let then_expr =
1061 then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1062 let else_expr =
1063 els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1064
1065 if cond_arity == inner.arity() {
1066 // If no additional columns were added, we simply return the
1067 // `If` variant with the updated expressions.
1068 SS::If {
1069 cond: Box::new(cond_expr),
1070 then: Box::new(then_expr),
1071 els: Box::new(else_expr),
1072 }
1073 } else {
1074 // If columns were added, we need a more careful approach, as
1075 // described above. First, we need to de-correlate each of
1076 // the two expressions independently, and apply their cases
1077 // as `MirRelationExpr::Map` operations.
1078
1079 *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
1080 // Restrict to records satisfying `cond_expr` and apply `then` as a map.
1081 let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
1082 let then_expr = then_clone.applied_to(
1083 id_gen,
1084 col_map,
1085 cte_map,
1086 &mut then_inner,
1087 subquery_map,
1088 context,
1089 )?;
1090 let then_arity = then_inner.arity();
1091 then_inner = then_inner
1092 .map_one(then_expr)
1093 .project((0..inner_arity).chain(Some(then_arity)).collect());
1094
1095 // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
1096 let mut else_inner = get_inner.filter(vec![SS::CallVariadic {
1097 func: mz_expr::VariadicFunc::Or,
1098 exprs: vec![
1099 cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
1100 cond_expr.clone().call_is_null(),
1101 ],
1102 }]);
1103 let else_expr = else_clone.applied_to(
1104 id_gen,
1105 col_map,
1106 cte_map,
1107 &mut else_inner,
1108 subquery_map,
1109 context,
1110 )?;
1111 let else_arity = else_inner.arity();
1112 else_inner = else_inner
1113 .map_one(else_expr)
1114 .project((0..inner_arity).chain(Some(else_arity)).collect());
1115
1116 // concatenate the two results.
1117 Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
1118 })?;
1119
1120 SS::Column(inner_arity, name)
1121 }
1122 }
1123
1124 // Subqueries!
1125 // These are surprisingly subtle. Things to be careful of:
1126
1127 // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
1128 // * change the row counts of the outer query
1129 // * accidentally compute its own value using the row counts of the outer query
1130 // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
1131 // query and then join the answers back on to the original rows of the outer query.
1132
1133 // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
1134 // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
1135 Exists(expr, name) => {
1136 let apply_requires_distinct_outer = true;
1137 *inner = apply_existential_subquery(
1138 id_gen,
1139 inner.take_dangerous(),
1140 col_map,
1141 cte_map,
1142 *expr,
1143 apply_requires_distinct_outer,
1144 context,
1145 )?;
1146 SS::Column(inner.arity() - 1, name)
1147 }
1148
1149 Select(expr, name) => {
1150 let apply_requires_distinct_outer = true;
1151 *inner = apply_scalar_subquery(
1152 id_gen,
1153 inner.take_dangerous(),
1154 col_map,
1155 cte_map,
1156 *expr,
1157 apply_requires_distinct_outer,
1158 context,
1159 )?;
1160 SS::Column(inner.arity() - 1, name)
1161 }
1162 Windowing(expr, _name) => {
1163 let partition_by = expr.partition_by;
1164 let order_by = expr.order_by;
1165
1166 // argument lowering for scalar window functions
1167 // (We need to specify the & _ in the arguments because of this problem:
1168 // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
1169 let scalar_lower_args =
1170 |_id_gen: &mut _,
1171 _col_map: &_,
1172 _cte_map: &mut _,
1173 _get_inner: &mut _,
1174 _subquery_map: &Option<&_>,
1175 order_by_mir: Vec<MirScalarExpr>,
1176 original_row_record,
1177 original_row_record_type: SqlScalarType| {
1178 let agg_input = MirScalarExpr::CallVariadic {
1179 func: mz_expr::VariadicFunc::ListCreate {
1180 elem_type: original_row_record_type.clone(),
1181 },
1182 exprs: vec![original_row_record],
1183 };
1184 let mut agg_input = vec![agg_input];
1185 agg_input.extend(order_by_mir.clone());
1186 let agg_input = MirScalarExpr::CallVariadic {
1187 func: mz_expr::VariadicFunc::RecordCreate {
1188 field_names: (0..agg_input.len())
1189 .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1190 .collect_vec(),
1191 },
1192 exprs: agg_input,
1193 };
1194 let list_type = SqlScalarType::List {
1195 element_type: Box::new(original_row_record_type),
1196 custom_id: None,
1197 };
1198 let agg_input_type = SqlScalarType::Record {
1199 fields: std::iter::once(&list_type)
1200 .map(|t| {
1201 (
1202 ColumnName::from(UNKNOWN_COLUMN_NAME),
1203 t.clone().nullable(false),
1204 )
1205 })
1206 .collect(),
1207 custom_id: None,
1208 }
1209 .nullable(false);
1210
1211 Ok((agg_input, agg_input_type))
1212 };
1213
1214 // argument lowering for value window functions and aggregate window functions
1215 let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
1216 |id_gen: &mut _,
1217 col_map: &_,
1218 cte_map: &mut _,
1219 get_inner: &mut _,
1220 subquery_map: &Option<&_>,
1221 order_by_mir: Vec<MirScalarExpr>,
1222 original_row_record,
1223 original_row_record_type| {
1224 // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]
1225
1226 // Compute the encoded args for all rows
1227 let mir_encoded_args = hir_encoded_args.applied_to(
1228 id_gen,
1229 col_map,
1230 cte_map,
1231 get_inner,
1232 subquery_map,
1233 context,
1234 )?;
1235 let mir_encoded_args_type = mir_encoded_args
1236 .typ(&get_inner.typ().column_types)
1237 .scalar_type;
1238
1239 // Build a new record that has two fields:
1240 // 1. the original row in a record
1241 // 2. the encoded args (which can be either a single value, or a record
1242 // if the window function has multiple arguments, such as `lag`)
1243 let fn_input_record_fields: Box<[_]> =
1244 [original_row_record_type, mir_encoded_args_type]
1245 .iter()
1246 .map(|t| {
1247 (
1248 ColumnName::from(UNKNOWN_COLUMN_NAME),
1249 t.clone().nullable(false),
1250 )
1251 })
1252 .collect();
1253 let fn_input_record = MirScalarExpr::CallVariadic {
1254 func: mz_expr::VariadicFunc::RecordCreate {
1255 field_names: fn_input_record_fields
1256 .iter()
1257 .map(|(n, _)| n.clone())
1258 .collect_vec(),
1259 },
1260 exprs: vec![original_row_record, mir_encoded_args],
1261 };
1262 let fn_input_record_type = SqlScalarType::Record {
1263 fields: fn_input_record_fields,
1264 custom_id: None,
1265 }
1266 .nullable(false);
1267
1268 // Build a new record with the record above + the ORDER BY exprs
1269 // This follows the standard encoding of ORDER BY exprs used by aggregate functions
1270 let mut agg_input = vec![fn_input_record];
1271 agg_input.extend(order_by_mir.clone());
1272 let agg_input = MirScalarExpr::CallVariadic {
1273 func: mz_expr::VariadicFunc::RecordCreate {
1274 field_names: (0..agg_input.len())
1275 .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1276 .collect_vec(),
1277 },
1278 exprs: agg_input,
1279 };
1280
1281 let agg_input_type = SqlScalarType::Record {
1282 fields: [(
1283 ColumnName::from(UNKNOWN_COLUMN_NAME),
1284 fn_input_record_type.nullable(false),
1285 )]
1286 .into(),
1287 custom_id: None,
1288 }
1289 .nullable(false);
1290
1291 Ok((agg_input, agg_input_type))
1292 }
1293 };
1294
1295 match expr.func {
1296 WindowExprType::Scalar(scalar_window_expr) => {
1297 let mir_aggr_func = scalar_window_expr.into_expr();
1298 Self::window_func_applied_to(
1299 id_gen,
1300 col_map,
1301 cte_map,
1302 inner,
1303 subquery_map,
1304 partition_by,
1305 order_by,
1306 mir_aggr_func,
1307 scalar_lower_args,
1308 context,
1309 )?
1310 }
1311 WindowExprType::Value(value_window_expr) => {
1312 let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();
1313
1314 Self::window_func_applied_to(
1315 id_gen,
1316 col_map,
1317 cte_map,
1318 inner,
1319 subquery_map,
1320 partition_by,
1321 order_by,
1322 mir_aggr_func,
1323 value_or_aggr_lower_args(hir_encoded_args),
1324 context,
1325 )?
1326 }
1327 WindowExprType::Aggregate(aggr_window_expr) => {
1328 let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();
1329
1330 Self::window_func_applied_to(
1331 id_gen,
1332 col_map,
1333 cte_map,
1334 inner,
1335 subquery_map,
1336 partition_by,
1337 order_by,
1338 mir_aggr_func,
1339 value_or_aggr_lower_args(hir_encoded_args),
1340 context,
1341 )?
1342 }
1343 }
1344 }
1345 })
1346 })
1347 }
1348
1349 fn window_func_applied_to<F>(
1350 id_gen: &mut mz_ore::id_gen::IdGen,
1351 col_map: &ColumnMap,
1352 cte_map: &mut CteMap,
1353 inner: &mut MirRelationExpr,
1354 subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
1355 partition_by: Vec<HirScalarExpr>,
1356 order_by: Vec<HirScalarExpr>,
1357 mir_aggr_func: AggregateFunc,
1358 lower_args: F,
1359 context: &Context,
1360 ) -> Result<MirScalarExpr, PlanError>
1361 where
1362 F: FnOnce(
1363 &mut mz_ore::id_gen::IdGen,
1364 &ColumnMap,
1365 &mut CteMap,
1366 &mut MirRelationExpr,
1367 &Option<&BTreeMap<HirScalarExpr, usize>>,
1368 Vec<MirScalarExpr>,
1369 MirScalarExpr,
1370 SqlScalarType,
1371 ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
1372 {
1373 // Example MIRs for a window function (specifically, a window aggregation):
1374 //
1375 // CREATE TABLE t7(x INT, y INT);
1376 //
1377 // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1378 //
1379 // Decorrelated Plan
1380 // Project (#3)
1381 // Map (#2)
1382 // Project (#3..=#5)
1383 // Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
1384 // FlatMap unnest_list(#1)
1385 // Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1386 // Map ((#0 + #1))
1387 // CrossJoin
1388 // Constant
1389 // - ()
1390 // Get materialize.public.t7
1391 //
1392 // The same query after optimizations:
1393 //
1394 // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1395 //
1396 // Optimized Plan
1397 // Explained Query:
1398 // Project (#2)
1399 // Map (record_get[0](#1))
1400 // FlatMap unnest_list(#0)
1401 // Project (#1)
1402 // Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1403 // ReadStorage materialize.public.t7
1404 //
1405 // The `row(row(row(...), ...), ...)` stuff means the following:
1406 // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
1407 // - The <arguments to window function> can be either a single value or itself a
1408 // `row` if there are multiple arguments.
1409 // - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
1410 // ORDER BY columns.
1411 // - The <original row> currently always captures the entire original row. This should
1412 // improve when we make `ProjectionPushdown` smarter, see
1413 // https://github.com/MaterializeInc/database-issues/issues/5090
1414 //
1415 // TODO:
1416 // We should probably introduce some dedicated Datum constructor functions instead of `row`
1417 // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
1418 // might even introduce dedicated Datum enum variants, so that the rendering code also
1419 // becomes more readable (and possibly slightly more performant).
1420
1421 *inner = inner
1422 .take_dangerous()
1423 .let_in(id_gen, |id_gen, mut get_inner| {
1424 let order_by_mir = order_by
1425 .into_iter()
1426 .map(|o| {
1427 o.applied_to(
1428 id_gen,
1429 col_map,
1430 cte_map,
1431 &mut get_inner,
1432 subquery_map,
1433 context,
1434 )
1435 })
1436 .collect::<Result<Vec<_>, _>>()?;
1437
1438 // Record input arity here so that any group_keys that need to mutate get_inner
1439 // don't add those columns to the aggregate input.
1440 let input_type = get_inner.typ();
1441 let input_arity = input_type.arity();
1442 // The reduction that computes the window function must be keyed on the columns
1443 // from the outer context, plus the expressions in the partition key. The current
1444 // subquery will be 'executed' for every distinct row from the outer context so
1445 // by putting the outer columns in the grouping key we isolate each re-execution.
1446 let mut group_key = col_map
1447 .inner
1448 .iter()
1449 .map(|(_, outer_col)| *outer_col)
1450 .sorted()
1451 .collect_vec();
1452 for p in partition_by {
1453 let key = p.applied_to(
1454 id_gen,
1455 col_map,
1456 cte_map,
1457 &mut get_inner,
1458 subquery_map,
1459 context,
1460 )?;
1461 if let MirScalarExpr::Column(c, _name) = key {
1462 group_key.push(c);
1463 } else {
1464 get_inner = get_inner.map_one(key);
1465 group_key.push(get_inner.arity() - 1);
1466 }
1467 }
1468
1469 get_inner.let_in(id_gen, |id_gen, mut get_inner| {
1470 // Original columns of the relation
1471 let fields: Box<_> = input_type
1472 .column_types
1473 .iter()
1474 .take(input_arity)
1475 .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
1476 .collect();
1477
1478 // Original row made into a record
1479 let original_row_record = MirScalarExpr::CallVariadic {
1480 func: mz_expr::VariadicFunc::RecordCreate {
1481 field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
1482 },
1483 exprs: (0..input_arity).map(MirScalarExpr::column).collect_vec(),
1484 };
1485 let original_row_record_type = SqlScalarType::Record {
1486 fields,
1487 custom_id: None,
1488 };
1489
1490 let (agg_input, agg_input_type) = lower_args(
1491 id_gen,
1492 col_map,
1493 cte_map,
1494 &mut get_inner,
1495 subquery_map,
1496 order_by_mir,
1497 original_row_record,
1498 original_row_record_type,
1499 )?;
1500
1501 let aggregate = mz_expr::AggregateExpr {
1502 func: mir_aggr_func,
1503 expr: agg_input,
1504 distinct: false,
1505 };
1506
1507 // Actually call reduce with the window function
1508 // The output of the aggregation function should be a list of tuples that has
1509 // the result in the first position, and the original row in the second position
1510 let mut reduce = get_inner
1511 .reduce(group_key.clone(), vec![aggregate.clone()], None)
1512 .flat_map(
1513 mz_expr::TableFunc::UnnestList {
1514 el_typ: aggregate
1515 .func
1516 .output_type(agg_input_type)
1517 .scalar_type
1518 .unwrap_list_element_type()
1519 .clone(),
1520 },
1521 vec![MirScalarExpr::column(group_key.len())],
1522 );
1523 let record_col = reduce.arity() - 1;
1524
1525 // Unpack the record output by the window function
1526 for c in 0..input_arity {
1527 reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1528 func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
1529 expr: Box::new(MirScalarExpr::CallUnary {
1530 func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
1531 expr: Box::new(MirScalarExpr::column(record_col)),
1532 }),
1533 });
1534 }
1535
1536 // Append the column with the result of the window function.
1537 reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1538 func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
1539 expr: Box::new(MirScalarExpr::column(record_col)),
1540 });
1541
1542 let agg_col = record_col + 1 + input_arity;
1543 Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
1544 })
1545 })?;
1546 Ok(MirScalarExpr::column(inner.arity() - 1))
1547 }
1548
1549 /// Applies the subqueries in the given list of scalar expressions to every distinct
1550 /// value of the given relation and returns a join of the given relation with all
1551 /// the subqueries found, and the mapping of scalar expressions with columns projected
1552 /// by the returned join that will hold their results.
1553 fn lower_subqueries(
1554 exprs: &[Self],
1555 id_gen: &mut mz_ore::id_gen::IdGen,
1556 col_map: &ColumnMap,
1557 cte_map: &mut CteMap,
1558 inner: MirRelationExpr,
1559 context: &Context,
1560 ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1561 let mut subquery_map = BTreeMap::new();
1562 let output = inner.let_in(id_gen, |id_gen, get_inner| {
1563 let mut subqueries = Vec::new();
1564 let distinct_inner = get_inner.clone().distinct();
1565 for expr in exprs.iter() {
1566 expr.visit_pre_post(
1567 &mut |e| match e {
1568 // For simplicity, subqueries within a conditional statement will be
1569 // lowered when lowering the conditional expression.
1570 HirScalarExpr::If { .. } => Some(vec![]),
1571 _ => None,
1572 },
1573 &mut |e| match e {
1574 HirScalarExpr::Select(expr, _name) => {
1575 let apply_requires_distinct_outer = false;
1576 let subquery = apply_scalar_subquery(
1577 id_gen,
1578 distinct_inner.clone(),
1579 col_map,
1580 cte_map,
1581 (**expr).clone(),
1582 apply_requires_distinct_outer,
1583 context,
1584 )
1585 .unwrap();
1586
1587 subqueries.push((e.clone(), subquery));
1588 }
1589 HirScalarExpr::Exists(expr, _name) => {
1590 let apply_requires_distinct_outer = false;
1591 let subquery = apply_existential_subquery(
1592 id_gen,
1593 distinct_inner.clone(),
1594 col_map,
1595 cte_map,
1596 (**expr).clone(),
1597 apply_requires_distinct_outer,
1598 context,
1599 )
1600 .unwrap();
1601 subqueries.push((e.clone(), subquery));
1602 }
1603 _ => {}
1604 },
1605 )?;
1606 }
1607
1608 if subqueries.is_empty() {
1609 Ok::<MirRelationExpr, PlanError>(get_inner)
1610 } else {
1611 let inner_arity = get_inner.arity();
1612 let mut total_arity = inner_arity;
1613 let mut join_inputs = vec![get_inner];
1614 let mut join_input_arities = vec![inner_arity];
1615 for (expr, subquery) in subqueries.into_iter() {
1616 // Avoid lowering duplicated subqueries
1617 if !subquery_map.contains_key(&expr) {
1618 let subquery_arity = subquery.arity();
1619 assert_eq!(subquery_arity, inner_arity + 1);
1620 join_inputs.push(subquery);
1621 join_input_arities.push(subquery_arity);
1622 total_arity += subquery_arity;
1623
1624 // Column with the value of the subquery
1625 subquery_map.insert(expr, total_arity - 1);
1626 }
1627 }
1628 // Each subquery projects all the columns of the outer context (distinct_inner)
1629 // plus 1 column, containing the result of the subquery. Those columns must be
1630 // joined with the outer/main relation (get_inner).
1631 let input_mapper =
1632 mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1633 let equivalences = (0..inner_arity)
1634 .map(|col| {
1635 join_inputs
1636 .iter()
1637 .enumerate()
1638 .map(|(input, _)| {
1639 MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1640 })
1641 .collect_vec()
1642 })
1643 .collect_vec();
1644 Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1645 }
1646 })?;
1647 Ok((output, subquery_map))
1648 }
1649
1650 /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1651 ///
1652 /// Returns an _internal_ error if the expression contains
1653 /// - a subquery
1654 /// - a column reference to an outer level
1655 /// - a parameter
1656 /// - a window function call
1657 ///
1658 /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1659 pub fn lower_uncorrelated(self) -> Result<MirScalarExpr, PlanError> {
1660 use MirScalarExpr as SS;
1661
1662 use HirScalarExpr::*;
1663
1664 Ok(match self {
1665 Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1666 Literal(datum, typ, _name) => SS::Literal(Ok(datum), typ),
1667 CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1668 CallUnary {
1669 func,
1670 expr,
1671 name: _,
1672 } => SS::CallUnary {
1673 func,
1674 expr: Box::new(expr.lower_uncorrelated()?),
1675 },
1676 CallBinary {
1677 func,
1678 expr1,
1679 expr2,
1680 name: _,
1681 } => SS::CallBinary {
1682 func,
1683 expr1: Box::new(expr1.lower_uncorrelated()?),
1684 expr2: Box::new(expr2.lower_uncorrelated()?),
1685 },
1686 CallVariadic {
1687 func,
1688 exprs,
1689 name: _,
1690 } => SS::CallVariadic {
1691 func,
1692 exprs: exprs
1693 .into_iter()
1694 .map(|expr| expr.lower_uncorrelated())
1695 .collect::<Result<_, _>>()?,
1696 },
1697 If {
1698 cond,
1699 then,
1700 els,
1701 name: _,
1702 } => SS::If {
1703 cond: Box::new(cond.lower_uncorrelated()?),
1704 then: Box::new(then.lower_uncorrelated()?),
1705 els: Box::new(els.lower_uncorrelated()?),
1706 },
1707 Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1708 sql_bail!(
1709 "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1710 self
1711 );
1712 }
1713 })
1714 }
1715}
1716
1717/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
1718/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
1719/// will, in effect, be executed once for every distinct row in `outer`, and the
1720/// results will be joined with `outer`. Note that columns in `outer` that are
1721/// not depended upon by `inner` are thrown away before the distinct, so that we
1722/// don't perform needless computation of `inner`.
1723///
1724/// `branch` will inspect the contents of `inner` to determine whether `inner`
1725/// is not multiplicity sensitive (roughly, contains only maps, filters,
1726/// projections, and calls to table functions). If it is not multiplicity
1727/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
1728/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
1729/// operations that were not present in `inner`, then set
1730/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
1731/// that distinctifies `outer`.
1732///
1733/// The caller must supply the `apply` function that applies the rewritten
1734/// `inner` to `outer`.
1735fn branch<F>(
1736 id_gen: &mut mz_ore::id_gen::IdGen,
1737 outer: MirRelationExpr,
1738 col_map: &ColumnMap,
1739 cte_map: &mut CteMap,
1740 inner: HirRelationExpr,
1741 apply_requires_distinct_outer: bool,
1742 context: &Context,
1743 apply: F,
1744) -> Result<MirRelationExpr, PlanError>
1745where
1746 F: FnOnce(
1747 &mut mz_ore::id_gen::IdGen,
1748 HirRelationExpr,
1749 MirRelationExpr,
1750 &ColumnMap,
1751 &mut CteMap,
1752 &Context,
1753 ) -> Result<MirRelationExpr, PlanError>,
1754{
1755 // TODO: It would be nice to have a version of this code w/o optimizations,
1756 // at the least for purposes of understanding. It was difficult for one reader
1757 // to understand the required properties of `outer` and `col_map`.
1758
1759 // If the inner expression is sufficiently simple, it is safe to apply it
1760 // *directly* to outer, rather than applying it to the distinctified key
1761 // (see below).
1762 //
1763 // As an example, consider the following two queries:
1764 //
1765 // CREATE TABLE t (a int, b int);
1766 // SELECT a, series FROM t, generate_series(1, t.b) series;
1767 //
1768 // The "simple" path for the `SELECT` yields
1769 //
1770 // %0 =
1771 // | Get t
1772 // | FlatMap generate_series(1, #1)
1773 //
1774 // while the non-simple path yields:
1775 //
1776 // %0 =
1777 // | Get t
1778 //
1779 // %1 =
1780 // | Get t
1781 // | Distinct group=(#1)
1782 // | FlatMap generate_series(1, #0)
1783 //
1784 // %2 =
1785 // | LeftJoin %1 %2 (= #1 #2)
1786 //
1787 // There is a tradeoff here: the simple plan is stateless, but the non-
1788 // simple plan may do (much) less computation if there are only a few
1789 // distinct values of `t.b`.
1790 //
1791 // We apply a very simple heuristic here and take the simple path if `inner`
1792 // contains only maps, filters, projections, and calls to table functions.
1793 // The intuition is that straightforward usage of table functions should
1794 // take the simple path, while everything else should not. (In theory we
1795 // think this transformation is valid as long as `inner` does not contain a
1796 // Reduce, Distinct, or TopK node, but it is not always an optimization in
1797 // the general case.)
1798 //
1799 // TODO(benesch): this should all be handled by a proper optimizer, but
1800 // detecting the moment of decorrelation in the optimizer right now is too
1801 // hard.
1802 let mut is_simple = true;
1803 #[allow(deprecated)]
1804 inner.visit(0, &mut |expr, _| match expr {
1805 HirRelationExpr::Constant { .. }
1806 | HirRelationExpr::Project { .. }
1807 | HirRelationExpr::Map { .. }
1808 | HirRelationExpr::Filter { .. }
1809 | HirRelationExpr::CallTable { .. } => (),
1810 _ => is_simple = false,
1811 });
1812 if is_simple && !apply_requires_distinct_outer {
1813 let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
1814 return outer.let_in(id_gen, |id_gen, get_outer| {
1815 apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
1816 });
1817 }
1818
1819 // The key consists of the columns from the outer expression upon which the
1820 // inner relation depends. We discover these dependencies by walking the
1821 // inner relation expression and looking for column references whose level
1822 // escapes inner.
1823 //
1824 // At the end of this process, `key` contains the decorrelated position of
1825 // each outer column, according to the passed-in `col_map`, and
1826 // `new_col_map` maps each outer column to its new ordinal position in key.
1827 let mut outer_cols = BTreeSet::new();
1828 #[allow(deprecated)]
1829 inner.visit_columns(0, &mut |depth, col| {
1830 // Test if the column reference escapes the subquery.
1831 if col.level > depth {
1832 outer_cols.insert(ColumnRef {
1833 level: col.level - depth,
1834 column: col.column,
1835 });
1836 }
1837 });
1838 // Collect all the outer columns referenced by any CTE referenced by
1839 // the inner relation.
1840 #[allow(deprecated)]
1841 inner.visit(0, &mut |e, _| match e {
1842 HirRelationExpr::Get {
1843 id: mz_expr::Id::Local(id),
1844 ..
1845 } => {
1846 if let Some(cte_desc) = cte_map.get(id) {
1847 let cte_outer_arity = cte_desc.outer_relation.arity();
1848 outer_cols.extend(
1849 col_map
1850 .inner
1851 .iter()
1852 .filter(|(_, position)| **position < cte_outer_arity)
1853 .map(|(c, _)| {
1854 // `col_map` maps column references to column positions in
1855 // `outer`'s projection.
1856 // `outer_cols` is meant to contain the external column
1857 // references in `inner`.
1858 // Since `inner` defines a new scope, any column reference
1859 // in `col_map` is one level deeper when seen from within
1860 // `inner`, hence the +1.
1861 ColumnRef {
1862 level: c.level + 1,
1863 column: c.column,
1864 }
1865 }),
1866 );
1867 }
1868 }
1869 HirRelationExpr::Let { id, .. } => {
1870 // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
1871 // we would need to remove the old CTE with the same ID temporarily while
1872 // traversing the definition of the new CTE under the same ID.
1873 assert!(!cte_map.contains_key(id));
1874 }
1875 _ => {}
1876 });
1877 let mut new_col_map = BTreeMap::new();
1878 let mut key = vec![];
1879 for col in outer_cols {
1880 new_col_map.insert(col, key.len());
1881 key.push(col_map.get(&ColumnRef {
1882 // Note: `outer_cols` contains the external column references within `inner`.
1883 // We must compensate for `inner`'s scope when translating column references
1884 // as seen within `inner` to column references as seen from `outer`'s context,
1885 // hence the -1.
1886 level: col.level - 1,
1887 column: col.column,
1888 }));
1889 }
1890 let new_col_map = ColumnMap::new(new_col_map);
1891 outer.let_in(id_gen, |id_gen, get_outer| {
1892 let keyed_outer = if key.is_empty() {
1893 // Don't depend on outer at all if the branch is not correlated,
1894 // which yields vastly better query plans. Note that this is a bit
1895 // weird in that the branch will be computed even if outer has no
1896 // rows, whereas if it had been correlated it would not (and *could*
1897 // not) have been computed if outer had no rows, but the callers of
1898 // this function don't mind these somewhat-weird semantics.
1899 MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![]))
1900 } else {
1901 get_outer.clone().distinct_by(key.clone())
1902 };
1903 keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
1904 let oa = get_outer.arity();
1905 let branch = apply(
1906 id_gen,
1907 inner,
1908 get_keyed_outer,
1909 &new_col_map,
1910 cte_map,
1911 context,
1912 )?;
1913 let ba = branch.arity();
1914 let joined = MirRelationExpr::join(
1915 vec![get_outer.clone(), branch],
1916 key.iter()
1917 .enumerate()
1918 .map(|(i, &k)| vec![(0, k), (1, i)])
1919 .collect(),
1920 )
1921 // throw away the right-hand copy of the key we just joined on
1922 .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
1923 Ok(joined)
1924 })
1925 })
1926}
1927
1928fn apply_scalar_subquery(
1929 id_gen: &mut mz_ore::id_gen::IdGen,
1930 outer: MirRelationExpr,
1931 col_map: &ColumnMap,
1932 cte_map: &mut CteMap,
1933 scalar_subquery: HirRelationExpr,
1934 apply_requires_distinct_outer: bool,
1935 context: &Context,
1936) -> Result<MirRelationExpr, PlanError> {
1937 branch(
1938 id_gen,
1939 outer,
1940 col_map,
1941 cte_map,
1942 scalar_subquery,
1943 apply_requires_distinct_outer,
1944 context,
1945 |id_gen, expr, get_inner, col_map, cte_map, context| {
1946 // compute for every row in get_inner
1947 let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1948 let col_type = select.typ().column_types.into_last();
1949
1950 let inner_arity = get_inner.arity();
1951 // We must determine a count for each `get_inner` prefix,
1952 // and report an error if that count exceeds one.
1953 let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1954 // Count for each `get_inner` prefix.
1955 let counts = get_select.clone().reduce(
1956 (0..inner_arity).collect::<Vec<_>>(),
1957 vec![mz_expr::AggregateExpr {
1958 func: mz_expr::AggregateFunc::Count,
1959 expr: MirScalarExpr::literal_true(),
1960 distinct: false,
1961 }],
1962 None,
1963 );
1964
1965 let use_guard = context.config.enable_guard_subquery_tablefunc;
1966
1967 // Errors should result from counts > 1.
1968 let errors = if use_guard {
1969 counts
1970 .flat_map(
1971 mz_expr::TableFunc::GuardSubquerySize {
1972 column_type: col_type.clone().scalar_type,
1973 },
1974 vec![MirScalarExpr::column(inner_arity)],
1975 )
1976 .project(
1977 (0..inner_arity)
1978 .chain(Some(inner_arity + 1))
1979 .collect::<Vec<_>>(),
1980 )
1981 } else {
1982 counts
1983 .filter(vec![MirScalarExpr::column(inner_arity).call_binary(
1984 MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
1985 func::Gt,
1986 )])
1987 .project((0..inner_arity).collect::<Vec<_>>())
1988 .map_one(MirScalarExpr::literal(
1989 Err(mz_expr::EvalError::MultipleRowsFromSubquery),
1990 col_type.clone().scalar_type,
1991 ))
1992 };
1993 // Return `get_select` and any errors added in.
1994 Ok::<_, PlanError>(get_select.union(errors))
1995 })?;
1996 // append Null to anything that didn't return any rows
1997 let default = vec![(Datum::Null, col_type.scalar_type)];
1998 get_inner.lookup(id_gen, guarded, default)
1999 },
2000 )
2001}
2002
2003fn apply_existential_subquery(
2004 id_gen: &mut mz_ore::id_gen::IdGen,
2005 outer: MirRelationExpr,
2006 col_map: &ColumnMap,
2007 cte_map: &mut CteMap,
2008 subquery_expr: HirRelationExpr,
2009 apply_requires_distinct_outer: bool,
2010 context: &Context,
2011) -> Result<MirRelationExpr, PlanError> {
2012 branch(
2013 id_gen,
2014 outer,
2015 col_map,
2016 cte_map,
2017 subquery_expr,
2018 apply_requires_distinct_outer,
2019 context,
2020 |id_gen, expr, get_inner, col_map, cte_map, context| {
2021 let exists = expr
2022 // compute for every row in get_inner
2023 .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2024 // throw away actual values and just remember whether or not there were __any__ rows
2025 .distinct_by((0..get_inner.arity()).collect())
2026 // Append true to anything that returned any rows.
2027 .map(vec![MirScalarExpr::literal_true()]);
2028
2029 // append False to anything that didn't return any rows
2030 get_inner.lookup(id_gen, exists, vec![(Datum::False, SqlScalarType::Bool)])
2031 },
2032 )
2033}
2034
2035impl AggregateExpr {
2036 fn applied_to(
2037 self,
2038 id_gen: &mut mz_ore::id_gen::IdGen,
2039 col_map: &ColumnMap,
2040 cte_map: &mut CteMap,
2041 inner: &mut MirRelationExpr,
2042 context: &Context,
2043 ) -> Result<mz_expr::AggregateExpr, PlanError> {
2044 let AggregateExpr {
2045 func,
2046 expr,
2047 distinct,
2048 } = self;
2049
2050 Ok(mz_expr::AggregateExpr {
2051 func: func.into_expr(),
2052 expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2053 distinct,
2054 })
2055 }
2056}
2057
2058/// Attempts an efficient outer join, if `on` has equijoin structure.
2059///
2060/// Both `left` and `right` are decorrelated inputs.
2061///
2062/// The first `oa` columns correspond to an outer context: we should do the
2063/// outer join independently for each prefix. In the case that `on` contains
2064/// just some equality tests between columns of `left` and `right` and some
2065/// local predicates, we can employ a relatively simple plan.
2066///
2067/// The last `on_subquery_types.len()` columns correspond to results from
2068/// subqueries defined in the `on` clause - we treat those as theta-join
2069/// conditions that prohibit the use of the simple plan attempted here.
2070fn attempt_outer_equijoin(
2071 left: MirRelationExpr,
2072 right: MirRelationExpr,
2073 on: MirScalarExpr,
2074 on_subquery_types: Vec<SqlColumnType>,
2075 kind: JoinKind,
2076 oa: usize,
2077 id_gen: &mut mz_ore::id_gen::IdGen,
2078 context: &Context,
2079) -> Result<Option<MirRelationExpr>, PlanError> {
2080 // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
2081 // predicates that reference subqueries as long as these subqueries don't
2082 // reference `left` and `right` at the same time.
2083 //
2084 // TODO(database-issues#6828): This code can be improved as follows:
2085 //
2086 // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
2087 // 2. Use the canonicalized `on` predicate in the non-equijoin based
2088 // lowering strategy.
2089 // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
2090 // 4. Pass the classified `OnPredicates` as a parameter.
2091 // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
2092 //
2093 // Steps (1 + 2) require further investigation because we might change the
2094 // error semantics in case the `on` predicate contains a literal error..
2095
2096 let l_type = left.typ();
2097 let r_type = right.typ();
2098 let la = l_type.column_types.len() - oa;
2099 let ra = r_type.column_types.len() - oa;
2100 let sa = on_subquery_types.len();
2101
2102 // The output type contains [outer, left, right, sa] attributes.
2103 let mut output_type = Vec::with_capacity(oa + la + ra + sa);
2104 output_type.extend(l_type.column_types);
2105 output_type.extend(r_type.column_types.into_iter().skip(oa));
2106 output_type.extend(on_subquery_types);
2107
2108 // Generally healthy to do, but specifically `USING` conditions sometimes
2109 // put an `AND true` at the end of the `ON` condition.
2110 //
2111 // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
2112 // However, in that case it's not clear that we won't see regressions if
2113 // `on` simplifies to a literal error.
2114 let mut on = vec![on];
2115 mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);
2116
2117 // Form the left and right types without the outer attributes.
2118 output_type.drain(0..oa);
2119 let lt = output_type.drain(0..la).collect_vec();
2120 let rt = output_type.drain(0..ra).collect_vec();
2121 assert!(output_type.len() == sa);
2122
2123 let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
2124 if !on_predicates.is_equijoin(context) {
2125 return Ok(None);
2126 }
2127
2128 // If we've gotten this far, we can do the clever thing.
2129 // We'll want to use left and right multiple times
2130 let result = left.let_in(id_gen, |id_gen, get_left| {
2131 right.let_in(id_gen, |id_gen, get_right| {
2132 // TODO: we know that we can re-use the arrangements of left and right
2133 // needed for the inner join with each of the conditional outer joins.
2134 // It is not clear whether we should hint that, or just let the planner
2135 // and optimizer run and see what happens.
2136
2137 // We'll want the inner join (minus repeated columns)
2138 let join = MirRelationExpr::join(
2139 vec![get_left.clone(), get_right.clone()],
2140 (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
2141 )
2142 // remove those columns from `right` repeating the first `oa` columns.
2143 .project(
2144 (0..(oa + la))
2145 .chain((oa + la + oa)..(oa + la + oa + ra))
2146 .collect(),
2147 )
2148 // apply the filter constraints here, to ensure nulls are not matched.
2149 .filter(on);
2150
2151 // We'll want to re-use the results of the join multiple times.
2152 join.let_in(id_gen, |id_gen, get_join| {
2153 let mut result = get_join.clone();
2154
2155 // A collection of keys present in both left and right collections.
2156 let join_keys = on_predicates.join_keys();
2157 let both_keys_arity = join_keys.len();
2158 let both_keys = get_join.restrict(join_keys).distinct();
2159
2160 // The plan is now to determine the left and right rows matched in the
2161 // inner join, subtract them from left and right respectively, pad what
2162 // remains with nulls, and fold them in to `result`.
2163
2164 both_keys.let_in(id_gen, |_id_gen, get_both| {
2165 if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
2166 // Rows in `left` matched in the inner equijoin. This is
2167 // a semi-join between `left` and `both_keys`.
2168 let left_present = MirRelationExpr::join_scalars(
2169 vec![
2170 get_left
2171 .clone()
2172 // Push local predicates.
2173 .filter(on_predicates.lhs()),
2174 get_both.clone(),
2175 ],
2176 itertools::zip_eq(
2177 on_predicates.eq_lhs(),
2178 (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
2179 )
2180 .map(|(l_key, b_key)| [l_key, b_key].to_vec())
2181 .collect(),
2182 )
2183 .project((0..(oa + la)).collect());
2184
2185 // Determine the types of nulls to use as filler.
2186 let right_fill = rt
2187 .into_iter()
2188 .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
2189 .collect();
2190
2191 // Add to `result` absent elements, filled with typed nulls.
2192 result = left_present
2193 .negate()
2194 .union(get_left.clone())
2195 .map(right_fill)
2196 .union(result);
2197 }
2198
2199 if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
2200 // Rows in `right` matched in the inner equijoin. This
2201 // is a semi-join between `right` and `both_keys`.
2202 let right_present = MirRelationExpr::join_scalars(
2203 vec![
2204 get_right
2205 .clone()
2206 // Push local predicates.
2207 .filter(on_predicates.rhs()),
2208 get_both,
2209 ],
2210 itertools::zip_eq(
2211 on_predicates.eq_rhs(),
2212 (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
2213 )
2214 .map(|(r_key, b_key)| [r_key, b_key].to_vec())
2215 .collect(),
2216 )
2217 .project((0..(oa + ra)).collect());
2218
2219 // Determine the types of nulls to use as filler.
2220 let left_fill = lt
2221 .into_iter()
2222 .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
2223 .collect();
2224
2225 // Add to `result` absent elements, prepended with typed nulls.
2226 result = right_present
2227 .negate()
2228 .union(get_right.clone())
2229 .map(left_fill)
2230 // Permute left fill before right values.
2231 .project(
2232 itertools::chain!(
2233 0..oa, // Preserve `outer`.
2234 oa + ra..oa + la + ra, // Increment the next `la` cols by `ra`.
2235 oa..oa + ra // Decrement the next `ra` cols by `la`.
2236 )
2237 .collect(),
2238 )
2239 .union(result)
2240 }
2241
2242 Ok::<_, PlanError>(result)
2243 })
2244 })
2245 })
2246 })?;
2247 Ok(Some(result))
2248}
2249
2250/// A struct that represents the predicates in the `on` clause in a form
2251/// suitable for efficient planning outer joins with equijoin predicates.
2252struct OnPredicates {
2253 /// A store for classified `ON` predicates.
2254 ///
2255 /// Predicates that reference a single side are adjusted to assume an
2256 /// `outer × <side>` schema.
2257 predicates: Vec<OnPredicate>,
2258 /// Number of outer context columns.
2259 oa: usize,
2260}
2261
2262impl OnPredicates {
2263 const I_OUT: usize = 0; // outer context input position
2264 const I_LHS: usize = 1; // lhs input position
2265 const I_RHS: usize = 2; // rhs input position
2266 const I_SUB: usize = 3; // on subqueries input position
2267
2268 /// Classify the predicates in the `on` clause of an outer join.
2269 ///
2270 /// The other parameters are arities of the input parts:
2271 ///
2272 /// - `oa` is the arity of the `outer` context.
2273 /// - `la` is the arity of the `left` input.
2274 /// - `ra` is the arity of the `right` input.
2275 /// - `sa` is the arity of the `on` subqueries.
2276 ///
2277 /// The constructor assumes that:
2278 ///
2279 /// 1. The `on` parameter will be applied on a result that has the following
2280 /// schema `outer × left × right × on_subqueries`.
2281 /// 2. The `on` parameter is already adjusted to assume that schema.
2282 /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2283 /// MirScalarExpr` with `canonicalize_predicates`.
2284 fn new(
2285 oa: usize,
2286 la: usize,
2287 ra: usize,
2288 sa: usize,
2289 on: Vec<MirScalarExpr>,
2290 _context: &Context,
2291 ) -> Self {
2292 use mz_expr::BinaryFunc::Eq;
2293
2294 // Re-bind those locally for more compact pattern matching.
2295 const I_LHS: usize = OnPredicates::I_LHS;
2296 const I_RHS: usize = OnPredicates::I_RHS;
2297
2298 // Self parameters.
2299 let mut predicates = Vec::with_capacity(on.len());
2300
2301 // Helpers for populating `predicates`.
2302 let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2303 let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2304 let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2305 inner_join_mapper
2306 .lookup_inputs(expr)
2307 .filter(|&i| i != Self::I_OUT)
2308 .collect()
2309 };
2310 let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2311 inner_join_mapper
2312 .lookup_inputs(expr)
2313 .any(|i| i == Self::I_SUB)
2314 };
2315
2316 // Iterate over `on` elements and populate `predicates`.
2317 for mut predicate in on {
2318 if predicate.might_error() {
2319 tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2320 // Treat predicates that can produce a literal error as Theta.
2321 predicates.push(OnPredicate::Theta(predicate));
2322 } else if has_subquery_refs(&predicate) {
2323 tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2324 // Treat predicates referencing an `on` subquery as Theta.
2325 predicates.push(OnPredicate::Theta(predicate));
2326 } else if let MirScalarExpr::CallBinary {
2327 func: Eq(_),
2328 expr1,
2329 expr2,
2330 } = &mut predicate
2331 {
2332 // Obtain the non-outer inputs referenced by each side.
2333 let inputs1 = lookup_inputs(expr1);
2334 let inputs2 = lookup_inputs(expr2);
2335
2336 match (&inputs1[..], &inputs2[..]) {
2337 // Neither side references an input. This could be a
2338 // constant expression or an expression that depends only on
2339 // the outer context.
2340 ([], []) => {
2341 predicates.push(OnPredicate::Const(predicate));
2342 }
2343 // Both sides reference different inputs.
2344 ([I_LHS], [I_RHS]) => {
2345 let lhs = expr1.take();
2346 let mut rhs = expr2.take();
2347 rhs.permute(&rhs_permutation);
2348 predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2349 predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2350 predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2351 }
2352 // Both sides reference different inputs (swapped).
2353 ([I_RHS], [I_LHS]) => {
2354 let lhs = expr2.take();
2355 let mut rhs = expr1.take();
2356 rhs.permute(&rhs_permutation);
2357 predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2358 predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2359 predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2360 }
2361 // Both sides reference the left input or no input.
2362 ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2363 predicates.push(OnPredicate::Lhs(predicate));
2364 }
2365 // Both sides reference the right input or no input.
2366 ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2367 predicate.permute(&rhs_permutation);
2368 predicates.push(OnPredicate::Rhs(predicate));
2369 }
2370 // At least one side references more than one input.
2371 _ => {
2372 tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2373 predicates.push(OnPredicate::Theta(predicate));
2374 }
2375 }
2376 } else {
2377 // Obtain the non-outer inputs referenced by this predicate.
2378 let inputs = lookup_inputs(&predicate);
2379
2380 match &inputs[..] {
2381 // The predicate references no inputs. This could be a
2382 // constant expression or an expression that depends only on
2383 // the outer context.
2384 [] => {
2385 predicates.push(OnPredicate::Const(predicate));
2386 }
2387 // The predicate references only the left input.
2388 [I_LHS] => {
2389 predicates.push(OnPredicate::Lhs(predicate));
2390 }
2391 // The predicate references only the right input.
2392 [I_RHS] => {
2393 predicate.permute(&rhs_permutation);
2394 predicates.push(OnPredicate::Rhs(predicate));
2395 }
2396 // The predicate references both inputs.
2397 _ => {
2398 tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2399 predicates.push(OnPredicate::Theta(predicate));
2400 }
2401 }
2402 }
2403 }
2404
2405 Self { predicates, oa }
2406 }
2407
2408 /// Check if the predicates can be lowered with an equijoin-based strategy.
2409 fn is_equijoin(&self, context: &Context) -> bool {
2410 // Count each `OnPredicate` variant in `self.predicates`.
2411 let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2412 self.predicates.iter().fold(
2413 (0, 0, 0, 0, 0, 0),
2414 |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2415 (
2416 const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2417 lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2418 rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2419 eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2420 eq_cols + usize::from(matches!(p, OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column())),
2421 theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2422 )
2423 },
2424 );
2425
2426 let is_equijion = if context.config.enable_new_outer_join_lowering {
2427 // New classifier.
2428 eq_cnt > 0 && theta_cnt == 0
2429 } else {
2430 // Old classifier.
2431 eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2432 };
2433
2434 // Log an entry only if this is an equijoin according to the new classifier.
2435 if eq_cnt > 0 && theta_cnt == 0 {
2436 tracing::debug!(
2437 const_cnt,
2438 lhs_cnt,
2439 rhs_cnt,
2440 eq_cnt,
2441 eq_cols,
2442 theta_cnt,
2443 "OnPredicates::is_equijoin"
2444 );
2445 }
2446
2447 is_equijion
2448 }
2449
2450 /// Return an [`MirRelationExpr`] list that represents the keys for the
2451 /// equijoin. The list will contain the outer columns as a prefix.
2452 fn join_keys(&self) -> JoinKeys {
2453 // We could return either the `lhs` or the `rhs` of the keys used to
2454 // form the inner join as they are equated by the join condition.
2455 let join_keys = self.eq_lhs().collect::<Vec<_>>();
2456
2457 if join_keys.iter().all(|k| k.is_column()) {
2458 tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2459 JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2460 } else {
2461 tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2462 JoinKeys::Scalars(join_keys)
2463 }
2464 }
2465
2466 /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2467 /// conditions in the predicates list.
2468 ///
2469 /// The iterator will start with column references to the outer columns as a
2470 /// prefix.
2471 fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2472 itertools::chain(
2473 (0..self.oa).map(MirScalarExpr::column),
2474 self.predicates.iter().filter_map(|e| match e {
2475 OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2476 _ => None,
2477 }),
2478 )
2479 }
2480
2481 /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2482 /// conditions in the predicates list.
2483 ///
2484 /// The iterator will start with column references to the outer columns as a
2485 /// prefix.
2486 fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2487 itertools::chain(
2488 (0..self.oa).map(MirScalarExpr::column),
2489 self.predicates.iter().filter_map(|e| match e {
2490 OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2491 _ => None,
2492 }),
2493 )
2494 }
2495
2496 /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2497 /// [`OnPredicate::Const`] conditions in the predicates list.
2498 fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2499 self.predicates.iter().filter_map(|p| match p {
2500 // We treat Const predicates local to both inputs.
2501 OnPredicate::Const(p) => Some(p.clone()),
2502 OnPredicate::Lhs(p) => Some(p.clone()),
2503 OnPredicate::LhsConsequence(p) => Some(p.clone()),
2504 _ => None,
2505 })
2506 }
2507
2508 /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2509 /// [`OnPredicate::Const`] conditions in the predicates list.
2510 fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2511 self.predicates.iter().filter_map(|p| match p {
2512 // We treat Const predicates local to both inputs.
2513 OnPredicate::Const(p) => Some(p.clone()),
2514 OnPredicate::Rhs(p) => Some(p.clone()),
2515 OnPredicate::RhsConsequence(p) => Some(p.clone()),
2516 _ => None,
2517 })
2518 }
2519}
2520
2521enum OnPredicate {
2522 // A predicate that is either constant or references only outer columns.
2523 Const(MirScalarExpr),
2524 // A local predicate on the left-hand side of the join, i.e., it references only the left input
2525 // and possibly outer columns.
2526 //
2527 // This is one of the original predicates from the ON clause.
2528 //
2529 // One _must_ apply this predicate.
2530 Lhs(MirScalarExpr),
2531 // A local predicate on the left-hand side of the join, i.e., it references only the left input
2532 // and possibly outer columns.
2533 //
2534 // This is not one of the original predicates from the ON clause, but is just a consequence
2535 // of an original predicate in the ON clause, where the original predicate references both
2536 // inputs, but the consequence references only the left input.
2537 //
2538 // For example, the original predicate `input1.x = input2.a` has the consequence
2539 // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
2540 // prevents null skew, and also makes more CSE opportunities available when the left input's key
2541 // doesn't have a NOT NULL constraint, saving us an arrangement.
2542 //
2543 // Applying the predicate is optional, because the original predicate will be applied anyway.
2544 LhsConsequence(MirScalarExpr),
2545 // A local predicate on the right-hand side of the join.
2546 //
2547 // This is one of the original predicates from the ON clause.
2548 //
2549 // One _must_ apply this predicate.
2550 Rhs(MirScalarExpr),
2551 // A consequence of an original ON predicate, see above.
2552 RhsConsequence(MirScalarExpr),
2553 // An equality predicate between the two sides.
2554 Eq(MirScalarExpr, MirScalarExpr),
2555 // a non-equality predicate between the two sides.
2556 #[allow(dead_code)]
2557 Theta(MirScalarExpr),
2558}
2559
2560/// A set of join keys referencing an input.
2561///
2562/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
2563/// avoid changes (and thereby possible regressions) in plans that have equijoin
2564/// predicates consisting only of column refs.
2565///
2566/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
2567/// able to get rid of this code, but as it stands `Map` simplification seems
2568/// more cumbersome than `Project` simplification, so do this just to be sure.
2569enum JoinKeys {
2570 // A predicate that is either constant or references only outer columns.
2571 Outputs(Vec<usize>),
2572 // A local predicate on the left-hand side of the join.
2573 Scalars(Vec<MirScalarExpr>),
2574}
2575
2576impl JoinKeys {
2577 fn len(&self) -> usize {
2578 match self {
2579 JoinKeys::Outputs(outputs) => outputs.len(),
2580 JoinKeys::Scalars(scalars) => scalars.len(),
2581 }
2582 }
2583}
2584
2585/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
2586/// code.
2587trait LoweringExt {
2588 /// See [`MirRelationExpr::restrict`].
2589 fn restrict(self, join_keys: JoinKeys) -> Self;
2590}
2591
2592impl LoweringExt for MirRelationExpr {
2593 /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2594 fn restrict(self, join_keys: JoinKeys) -> Self {
2595 let num_keys = join_keys.len();
2596 match join_keys {
2597 JoinKeys::Outputs(outputs) => self.project(outputs),
2598 JoinKeys::Scalars(scalars) => {
2599 let input_arity = self.arity();
2600 let outputs = (input_arity..input_arity + num_keys).collect();
2601 self.map(scalars).project(outputs)
2602 }
2603 }
2604 }
2605}