mz_sql/plan/lowering.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::visit::Visit;
44use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr};
45use mz_ore::collections::CollectionExt;
46use mz_ore::stack::maybe_grow;
47use mz_repr::*;
48
49use crate::optimizer_metrics::OptimizerMetrics;
50use crate::plan::hir::{
51 AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
52};
53use crate::plan::{PlanError, transform_hir};
54use crate::session::vars::SystemVars;
55
56mod variadic_left;
57
58/// Maps a leveled column reference to a specific column.
59///
60/// Leveled column references are nested, so that larger levels are
61/// found early in a record and level zero is found at the end.
62///
63/// The column map only stores references for levels greater than zero,
64/// and column references at level zero simply start at the first column
65/// after all prior references.
66#[derive(Debug, Clone)]
67struct ColumnMap {
68 inner: BTreeMap<ColumnRef, usize>,
69}
70
71impl ColumnMap {
72 fn empty() -> ColumnMap {
73 Self::new(BTreeMap::new())
74 }
75
76 fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
77 ColumnMap { inner }
78 }
79
80 fn get(&self, col_ref: &ColumnRef) -> usize {
81 if col_ref.level == 0 {
82 self.inner.len() + col_ref.column
83 } else {
84 self.inner[col_ref]
85 }
86 }
87
88 fn len(&self) -> usize {
89 self.inner.len()
90 }
91
92 /// Updates references in the `ColumnMap` for use in a nested scope. The
93 /// provided `arity` must specify the arity of the current scope.
94 fn enter_scope(&self, arity: usize) -> ColumnMap {
95 // From the perspective of the nested scope, all existing column
96 // references will be one level greater.
97 let existing = self
98 .inner
99 .clone()
100 .into_iter()
101 .update(|(col, _i)| col.level += 1);
102
103 // All columns in the current scope become explicit entries in the
104 // immediate parent scope.
105 let new = (0..arity).map(|i| {
106 (
107 ColumnRef {
108 level: 1,
109 column: i,
110 },
111 self.len() + i,
112 )
113 });
114
115 ColumnMap::new(existing.chain(new).collect())
116 }
117}
118
/// Map with the CTEs currently in scope, keyed by the `LocalId` under which the
/// CTE is referenced in the expression being lowered.
type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
121
/// Information needed when resolving a reference to a CTE that is in scope.
#[derive(Clone)]
struct CteDesc {
    /// The new ID assigned to the lowered version of the CTE, which may not match
    /// the ID of the input CTE.
    new_id: mz_expr::LocalId,
    /// The relation type of the CTE including the columns from the outer
    /// context at the beginning.
    relation_type: RelationType,
    /// The outer relation the CTE was applied to.
    outer_relation: MirRelationExpr,
}
134
/// Feature flags that select between alternative lowering strategies.
///
/// Usually built from the system variables through `From<&SystemVars>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
}
142
143impl From<&SystemVars> for Config {
144 fn from(vars: &SystemVars) -> Self {
145 Self {
146 enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
147 enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
148 }
149 }
150}
151
/// Context passed to the lowering. This is wired to most parts of the lowering.
///
/// Both fields are borrowed for `'a`; the context itself owns nothing.
pub(crate) struct Context<'a> {
    /// Feature flags affecting the behavior of lowering.
    pub config: &'a Config,
    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's `None`,
    /// we simply don't write metrics.
    pub metrics: Option<&'a OptimizerMetrics>,
}
160
161impl HirRelationExpr {
162 /// Rewrite `self` into a `MirRelationExpr`.
163 /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
164 #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
165 pub fn lower<C: Into<Config>>(
166 self,
167 config: C,
168 metrics: Option<&OptimizerMetrics>,
169 ) -> Result<MirRelationExpr, PlanError> {
170 let context = Context {
171 config: &config.into(),
172 metrics,
173 };
174 let result = match self {
175 // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
176 // to ensure that the downstream optimizer can easily bypass most
177 // irrelevant optimizations (e.g. reduce folding) for this expression
178 // without having to re-learn the fact that it is just a constant,
179 // as it would if the constant were wrapped in a Let-Get pair.
180 HirRelationExpr::Constant { rows, typ } => {
181 let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
182 MirRelationExpr::Constant {
183 rows: Ok(rows),
184 typ,
185 }
186 }
187 mut other => {
188 let mut id_gen = mz_ore::id_gen::IdGen::default();
189 transform_hir::split_subquery_predicates(&mut other)?;
190 transform_hir::try_simplify_quantified_comparisons(&mut other)?;
191 transform_hir::fuse_window_functions(&mut other, &context)?;
192 MirRelationExpr::constant(vec![vec![]], RelationType::new(vec![])).let_in(
193 &mut id_gen,
194 |id_gen, get_outer| {
195 other.applied_to(
196 id_gen,
197 get_outer,
198 &ColumnMap::empty(),
199 &mut CteMap::new(),
200 &context,
201 )
202 },
203 )?
204 }
205 };
206
207 mz_repr::explain::trace_plan(&result);
208
209 Ok(result)
210 }
211
212 /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
213 ///
214 /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
215 /// When `self` references columns of `get_outer`, much more work needs to occur.
216 ///
217 /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
218 /// perhaps not all of them. It should be used as the basis of resolving column references,
219 /// but care must be taken when adding new columns that `get_outer.arity()` is where they
220 /// will start, rather than any function of `col_map`.
221 ///
222 /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
223 /// assignment of values to outer rows.
224 fn applied_to(
225 self,
226 id_gen: &mut mz_ore::id_gen::IdGen,
227 get_outer: MirRelationExpr,
228 col_map: &ColumnMap,
229 cte_map: &mut CteMap,
230 context: &Context,
231 ) -> Result<MirRelationExpr, PlanError> {
232 maybe_grow(|| {
233 use MirRelationExpr as SR;
234
235 use HirRelationExpr::*;
236
237 if let MirRelationExpr::Get { .. } = &get_outer {
238 } else {
239 panic!(
240 "get_outer: expected a MirRelationExpr::Get, found {:?}",
241 get_outer
242 );
243 }
244 assert_eq!(col_map.len(), get_outer.arity());
245 Ok(match self {
246 Constant { rows, typ } => {
247 // Constant expressions are not correlated with `get_outer`, and should be cross-products.
248 get_outer.product(SR::Constant {
249 rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
250 typ,
251 })
252 }
253 Get { id, typ } => match id {
254 mz_expr::Id::Local(local_id) => {
255 let cte_desc = cte_map.get(&local_id).unwrap();
256 let get_cte = SR::Get {
257 id: mz_expr::Id::Local(cte_desc.new_id.clone()),
258 typ: cte_desc.relation_type.clone(),
259 access_strategy: AccessStrategy::UnknownOrLocal,
260 };
261 if get_outer == cte_desc.outer_relation {
262 // If the CTE was applied to the same exact relation, we can safely
263 // return a `Get` relation.
264 get_cte
265 } else {
266 // Otherwise, the new outer relation may contain more columns from some
267 // intermediate scope placed between the definition of the CTE and this
268 // reference of the CTE and/or more operations applied on top of the
269 // outer relation.
270 //
271 // An example of the latter is the following query:
272 //
273 // SELECT *
274 // FROM x,
275 // LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
276 // SELECT (SELECT m FROM a) FROM y) b;
277 //
278 // When the CTE is lowered, the outer relation is `Get x`. But then,
279 // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
280 // which has the same cardinality as `Get x`.
281 //
282 // In any case, `get_outer` is guaranteed to contain the columns of the
283 // outer relation the CTE was applied to at its prefix. Since, we must
284 // return a relation containing `get_outer`'s column at the beginning,
285 // we must build a join between `get_outer` and `get_cte` on their common
286 // columns.
287 let oa = get_outer.arity();
288 let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
289 let equivalences = (0..cte_outer_columns)
290 .map(|pos| {
291 vec![
292 MirScalarExpr::Column(pos),
293 MirScalarExpr::Column(pos + oa),
294 ]
295 })
296 .collect();
297
298 // Project out the second copy of the common between `get_outer` and
299 // `cte_desc.outer_relation`.
300 let projection = (0..oa)
301 .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
302 .collect_vec();
303 SR::join_scalars(vec![get_outer, get_cte], equivalences)
304 .project(projection)
305 }
306 }
307 _ => {
308 // Get statements are only to external sources, and are not correlated with `get_outer`.
309 get_outer.product(SR::Get {
310 id,
311 typ,
312 access_strategy: AccessStrategy::UnknownOrLocal,
313 })
314 }
315 },
316 Let {
317 name: _,
318 id,
319 value,
320 body,
321 } => {
322 let value =
323 value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
324 value.let_in(id_gen, |id_gen, get_value| {
325 let (new_id, typ) = if let MirRelationExpr::Get {
326 id: mz_expr::Id::Local(id),
327 typ,
328 ..
329 } = get_value
330 {
331 (id, typ)
332 } else {
333 panic!(
334 "get_value: expected a MirRelationExpr::Get with local Id, found {:?}",
335 get_value
336 );
337 };
338 // Add the information about the CTE to the map and remove it when
339 // it goes out of scope.
340 let old_value = cte_map.insert(
341 id.clone(),
342 CteDesc {
343 new_id,
344 relation_type: typ,
345 outer_relation: get_outer.clone(),
346 },
347 );
348 let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
349 if let Some(old_value) = old_value {
350 cte_map.insert(id, old_value);
351 } else {
352 cte_map.remove(&id);
353 }
354 body
355 })?
356 }
357 LetRec {
358 limit,
359 bindings,
360 body,
361 } => {
362 let num_bindings = bindings.len();
363
364 // We use the outer type with the HIR types to form MIR CTE types.
365 let outer_column_types = get_outer.typ().column_types;
366
367 // Rename and introduce all bindings.
368 let mut shadowed_bindings = Vec::with_capacity(num_bindings);
369 let mut mir_ids = Vec::with_capacity(num_bindings);
370 for (_name, id, _value, typ) in bindings.iter() {
371 let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
372 mir_ids.push(mir_id);
373 let shadowed = cte_map.insert(
374 id.clone(),
375 CteDesc {
376 new_id: mir_id,
377 relation_type: RelationType::new(
378 outer_column_types
379 .iter()
380 .cloned()
381 .chain(typ.column_types.iter().cloned())
382 .collect::<Vec<_>>(),
383 ),
384 outer_relation: get_outer.clone(),
385 },
386 );
387 shadowed_bindings.push((*id, shadowed));
388 }
389
390 let mut mir_values = Vec::with_capacity(num_bindings);
391 for (_name, _id, value, _typ) in bindings.into_iter() {
392 mir_values.push(value.applied_to(
393 id_gen,
394 get_outer.clone(),
395 col_map,
396 cte_map,
397 context,
398 )?);
399 }
400
401 let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
402
403 // Remove our bindings and reinstate any shadowed bindings.
404 for (id, shadowed) in shadowed_bindings {
405 if let Some(shadowed) = shadowed {
406 cte_map.insert(id, shadowed);
407 } else {
408 cte_map.remove(&id);
409 }
410 }
411
412 MirRelationExpr::LetRec {
413 ids: mir_ids,
414 values: mir_values,
415 // Copy the limit to each binding.
416 limits: repeat(limit).take(num_bindings).collect(),
417 body: Box::new(mir_body),
418 }
419 }
420 Project { input, outputs } => {
421 // Projections should be applied to the decorrelated `inner`, and to its columns,
422 // which means rebasing `outputs` to start `get_outer.arity()` columns later.
423 let input =
424 input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
425 let outputs = (0..get_outer.arity())
426 .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
427 .collect::<Vec<_>>();
428 input.project(outputs)
429 }
430 Map { input, mut scalars } => {
431 // Scalar expressions may contain correlated subqueries. We must be cautious!
432
433 // We lower scalars in chunks, and must keep track of the
434 // arity of the HIR fragments lowered so far.
435 let mut lowered_arity = input.arity();
436
437 let mut input =
438 input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
439
440 // Lower subqueries in maximally sized batches, such as no subquery in the current
441 // batch depends on columns from the same batch.
442 // Note that subqueries in this projection may reference columns added by this
443 // Map operator, so we need to ensure these columns exist before lowering the
444 // subquery.
445 while !scalars.is_empty() {
446 let end_idx = scalars
447 .iter_mut()
448 .position(|s| {
449 let mut requires_nonexistent_column = false;
450 #[allow(deprecated)]
451 s.visit_columns(0, &mut |depth, col| {
452 if col.level == depth {
453 requires_nonexistent_column |= col.column >= lowered_arity
454 }
455 });
456 requires_nonexistent_column
457 })
458 .unwrap_or(scalars.len());
459 assert!(
460 end_idx > 0,
461 "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
462 lowered_arity,
463 scalars
464 );
465
466 lowered_arity = lowered_arity + end_idx;
467 let scalars = scalars.drain(0..end_idx).collect_vec();
468
469 let old_arity = input.arity();
470 let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
471 &scalars, id_gen, col_map, cte_map, input, context,
472 )?;
473 input = with_subqueries;
474
475 // We will proceed sequentially through the scalar expressions, for each transforming
476 // the decorrelated `input` into a relation with potentially more columns capable of
477 // addressing the needs of the scalar expression.
478 // Having done so, we add the scalar value of interest and trim off any other newly
479 // added columns.
480 //
481 // The sequential traversal is present as expressions are allowed to depend on the
482 // values of prior expressions.
483 let mut scalar_columns = Vec::new();
484 for scalar in scalars {
485 let scalar = scalar.applied_to(
486 id_gen,
487 col_map,
488 cte_map,
489 &mut input,
490 &Some(&subquery_map),
491 context,
492 )?;
493 input = input.map_one(scalar);
494 scalar_columns.push(input.arity() - 1);
495 }
496
497 // Discard any new columns added by the lowering of the scalar expressions
498 input = input.project((0..old_arity).chain(scalar_columns).collect());
499 }
500
501 input
502 }
503 CallTable { func, exprs } => {
504 // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
505 // allowed to refer to the results of previous expressions, and we have a simpler
506 // implementation that appends all relevant columns first, then applies the flatmap
507 // operator to the result, then strips off any columns introduce by subqueries.
508
509 let mut input = get_outer;
510 let old_arity = input.arity();
511
512 let exprs = exprs
513 .into_iter()
514 .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
515 .collect::<Result<Vec<_>, _>>()?;
516
517 let new_arity = input.arity();
518 let output_arity = func.output_arity();
519 input = input.flat_map(func, exprs);
520 if old_arity != new_arity {
521 // this means we added some columns to handle subqueries, and now we need to get rid of them
522 input = input.project(
523 (0..old_arity)
524 .chain(new_arity..new_arity + output_arity)
525 .collect(),
526 );
527 }
528 input
529 }
530 Filter { input, predicates } => {
531 // Filter expressions may contain correlated subqueries.
532 // We extend `get_outer` with sufficient values to determine the value of the predicate,
533 // then filter the results, then strip off any columns that were added for this purpose.
534 let mut input =
535 input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
536 for predicate in predicates {
537 let old_arity = input.arity();
538 let predicate = predicate
539 .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
540 let new_arity = input.arity();
541 input = input.filter(vec![predicate]);
542 if old_arity != new_arity {
543 // this means we added some columns to handle subqueries, and now we need to get rid of them
544 input = input.project((0..old_arity).collect());
545 }
546 }
547 input
548 }
549 Join {
550 left,
551 right,
552 on,
553 kind,
554 } if right.is_correlated() => {
555 // A correlated join is a join in which the right expression has
556 // access to the columns in the left expression. It turns out
557 // this is *exactly* our branch operator, plus some additional
558 // null handling in the case of left joins. (Right and full
559 // lateral joins are not permitted.)
560 //
561 // As with normal joins, the `on` predicate may be correlated,
562 // and we treat it as a filter that follows the branch.
563
564 assert!(kind.can_be_correlated());
565
566 let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
567 left.let_in(id_gen, |id_gen, get_left| {
568 let apply_requires_distinct_outer = false;
569 let mut join = branch(
570 id_gen,
571 get_left.clone(),
572 col_map,
573 cte_map,
574 *right,
575 apply_requires_distinct_outer,
576 context,
577 |id_gen, right, get_left, col_map, cte_map, context| {
578 right.applied_to(id_gen, get_left, col_map, cte_map, context)
579 },
580 )?;
581
582 // Plan the `on` predicate.
583 let old_arity = join.arity();
584 let on =
585 on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
586 join = join.filter(vec![on]);
587 let new_arity = join.arity();
588 if old_arity != new_arity {
589 // This means we added some columns to handle
590 // subqueries, and now we need to get rid of them.
591 join = join.project((0..old_arity).collect());
592 }
593
594 // If a left join, reintroduce any rows from the left that
595 // are missing, with nulls filled in for the right columns.
596 if let JoinKind::LeftOuter { .. } = kind {
597 let default = join
598 .typ()
599 .column_types
600 .into_iter()
601 .skip(get_left.arity())
602 .map(|typ| (Datum::Null, typ.scalar_type))
603 .collect();
604 get_left.lookup(id_gen, join, default)
605 } else {
606 Ok::<_, PlanError>(join)
607 }
608 })?
609 }
610 Join {
611 left,
612 right,
613 on,
614 kind,
615 } => {
616 if context.config.enable_variadic_left_join_lowering {
617 // Attempt to extract a stack of left joins.
618 if let JoinKind::LeftOuter = kind {
619 let mut rights = vec![(&*right, &on)];
620 let mut left_test = &left;
621 while let Join {
622 left,
623 right,
624 on,
625 kind: JoinKind::LeftOuter,
626 } = &**left_test
627 {
628 rights.push((&**right, on));
629 left_test = left;
630 }
631 if rights.len() > 1 {
632 // Defensively clone `cte_map` as it may be mutated.
633 let cte_map_clone = cte_map.clone();
634 if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
635 left_test,
636 rights,
637 id_gen,
638 get_outer.clone(),
639 col_map,
640 cte_map,
641 context,
642 ) {
643 return Ok(magic);
644 } else {
645 cte_map.clone_from(&cte_map_clone);
646 }
647 }
648 }
649 }
650
651 // Both join expressions should be decorrelated, and then joined by their
652 // leading columns to form only those pairs corresponding to the same row
653 // of `get_outer`.
654 //
655 // The `on` predicate may contain correlated subqueries, and we treat it
656 // as though it was a filter, with the caveat that we also translate outer
657 // joins in this step. The post-filtration results need to be considered
658 // against the records present in the left and right (decorrelated) inputs,
659 // depending on the type of join.
660 let oa = get_outer.arity();
661 let left =
662 left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
663 let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
664 let la = lt.len();
665 left.let_in(id_gen, |id_gen, get_left| {
666 let right_col_map = col_map.enter_scope(0);
667 let right = right.applied_to(
668 id_gen,
669 get_outer.clone(),
670 &right_col_map,
671 cte_map,
672 context,
673 )?;
674 let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
675 let ra = rt.len();
676 right.let_in(id_gen, |id_gen, get_right| {
677 let mut product = SR::join(
678 vec![get_left.clone(), get_right.clone()],
679 (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
680 )
681 // Project away the repeated copy of get_outer's columns.
682 .project(
683 (0..(oa + la))
684 .chain((oa + la + oa)..(oa + la + oa + ra))
685 .collect(),
686 );
687
688 // Decorrelate and lower the `on` clause.
689 let on = on.applied_to(
690 id_gen,
691 col_map,
692 cte_map,
693 &mut product,
694 &None,
695 context,
696 )?;
697 // Collect the types of all subqueries appearing in
698 // the `on` clause. The subquery results were
699 // appended to `product` in the `on.applied_to(...)`
700 // call above.
701 let on_subquery_types = product
702 .typ()
703 .column_types
704 .drain(oa + la + ra..)
705 .collect_vec();
706 // Remember if `on` had any subqueries.
707 let on_has_subqueries = !on_subquery_types.is_empty();
708
709 // Attempt an efficient equijoin implementation, in which outer joins are
710 // more efficiently rendered than in general. This can return `None` if
711 // such a plan is not possible, for example if `on` does not describe an
712 // equijoin between columns of `left` and `right`.
713 if kind != JoinKind::Inner {
714 if let Some(joined) = attempt_outer_equijoin(
715 get_left.clone(),
716 get_right.clone(),
717 on.clone(),
718 on_subquery_types,
719 kind.clone(),
720 oa,
721 id_gen,
722 context,
723 )? {
724 if let Some(metrics) = context.metrics {
725 metrics.inc_outer_join_lowering("equi");
726 }
727 return Ok(joined);
728 }
729 }
730
731 // Otherwise, perform a more general join.
732 if let Some(metrics) = context.metrics {
733 metrics.inc_outer_join_lowering("general");
734 }
735 let mut join = product.filter(vec![on]);
736 if on_has_subqueries {
737 // This means that `on.applied_to(...)` appended
738 // some columns to handle subqueries, and now we
739 // need to get rid of them.
740 join = join.project((0..oa + la + ra).collect());
741 }
742 join.let_in(id_gen, |id_gen, get_join| {
743 let mut result = get_join.clone();
744 if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
745 kind
746 {
747 let left_outer = get_left.clone().anti_lookup::<PlanError>(
748 id_gen,
749 get_join.clone(),
750 rt.into_iter()
751 .map(|typ| (Datum::Null, typ.scalar_type))
752 .collect(),
753 )?;
754 result = result.union(left_outer);
755 }
756 if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
757 let right_outer = get_right
758 .clone()
759 .anti_lookup::<PlanError>(
760 id_gen,
761 get_join
762 // need to swap left and right to make the anti_lookup work
763 .project(
764 (0..oa)
765 .chain((oa + la)..(oa + la + ra))
766 .chain((oa)..(oa + la))
767 .collect(),
768 ),
769 lt.into_iter()
770 .map(|typ| (Datum::Null, typ.scalar_type))
771 .collect(),
772 )?
773 // swap left and right back again
774 .project(
775 (0..oa)
776 .chain((oa + ra)..(oa + ra + la))
777 .chain((oa)..(oa + ra))
778 .collect(),
779 );
780 result = result.union(right_outer);
781 }
782 Ok::<MirRelationExpr, PlanError>(result)
783 })
784 })
785 })?
786 }
787 Union { base, inputs } => {
788 // Union is uncomplicated.
789 SR::Union {
790 base: Box::new(base.applied_to(
791 id_gen,
792 get_outer.clone(),
793 col_map,
794 cte_map,
795 context,
796 )?),
797 inputs: inputs
798 .into_iter()
799 .map(|input| {
800 input.applied_to(
801 id_gen,
802 get_outer.clone(),
803 col_map,
804 cte_map,
805 context,
806 )
807 })
808 .collect::<Result<Vec<_>, _>>()?,
809 }
810 }
811 Reduce {
812 input,
813 group_key,
814 aggregates,
815 expected_group_size,
816 } => {
817 // Reduce may contain expressions with correlated subqueries.
818 // In addition, here an empty reduction key signifies that we need to supply default values
819 // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
820 let mut input =
821 input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
822 let applied_group_key = (0..get_outer.arity())
823 .chain(group_key.iter().map(|i| get_outer.arity() + i))
824 .collect();
825 let applied_aggregates = aggregates
826 .into_iter()
827 .map(|aggregate| {
828 aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
829 })
830 .collect::<Result<Vec<_>, _>>()?;
831 let input_type = input.typ();
832 let default = applied_aggregates
833 .iter()
834 .map(|agg| {
835 (
836 agg.func.default(),
837 agg.typ(&input_type.column_types).scalar_type,
838 )
839 })
840 .collect();
841 // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
842 let mut reduced =
843 input.reduce(applied_group_key, applied_aggregates, expected_group_size);
844
845 // Introduce default values in the case the group key is empty.
846 if group_key.is_empty() {
847 reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
848 }
849 reduced
850 }
851 Distinct { input } => {
852 // Distinct is uncomplicated.
853 input
854 .applied_to(id_gen, get_outer, col_map, cte_map, context)?
855 .distinct()
856 }
857 TopK {
858 input,
859 group_key,
860 order_key,
861 limit,
862 offset,
863 expected_group_size,
864 } => {
865 // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
866 let mut input =
867 input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
868 let mut applied_group_key: Vec<_> = (0..get_outer.arity())
869 .chain(group_key.iter().map(|i| get_outer.arity() + i))
870 .collect();
871 let applied_order_key = order_key
872 .iter()
873 .map(|column_order| ColumnOrder {
874 column: column_order.column + get_outer.arity(),
875 desc: column_order.desc,
876 nulls_last: column_order.nulls_last,
877 })
878 .collect();
879
880 let old_arity = input.arity();
881
882 // Lower `limit`, which may introduce new columns if is a correlated subquery.
883 let mut limit_mir = None;
884 if let Some(limit) = limit {
885 limit_mir = Some(
886 limit
887 .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
888 );
889 }
890
891 let new_arity = input.arity();
892 // Extend the key to contain any new columns.
893 applied_group_key.extend(old_arity..new_arity);
894
895 let mut result = input.top_k(
896 applied_group_key,
897 applied_order_key,
898 limit_mir,
899 offset,
900 expected_group_size,
901 );
902
903 // If new columns were added for `limit` we must remove them.
904 if old_arity != new_arity {
905 result = result.project((0..old_arity).collect());
906 }
907
908 result
909 }
910 Negate { input } => {
911 // Negate is uncomplicated.
912 input
913 .applied_to(id_gen, get_outer, col_map, cte_map, context)?
914 .negate()
915 }
916 Threshold { input } => {
917 // Threshold is uncomplicated.
918 input
919 .applied_to(id_gen, get_outer, col_map, cte_map, context)?
920 .threshold()
921 }
922 })
923 })
924 }
925}
926
927impl HirScalarExpr {
928 /// Rewrite `self` into a `mz_expr::ScalarExpr` which can be applied to the modified `inner`.
929 ///
930 /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
931 /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
932 /// The most complicated logic is for the scalar expressions that involve subqueries, each of which are
933 /// documented in more detail closer to their logic.
934 ///
935 /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
936 /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
937 /// each of these references can be found.
    ///
    /// Returns the `MirScalarExpr` that computes `self` over `inner`. Lowering a subquery,
    /// a window function, or an `If` whose branches introduce columns replaces `*inner`
    /// with a relation carrying additional columns; in those cases the returned expression
    /// is a reference to the column that holds the result.
    ///
    /// If `subquery_map` is provided and contains `self`, the mapped column is returned
    /// directly and no further lowering happens for this expression.
    ///
    /// # Panics
    ///
    /// Panics if a `Parameter` expression is reached during lowering.
    fn applied_to(
        self,
        id_gen: &mut mz_ore::id_gen::IdGen,
        col_map: &ColumnMap,
        cte_map: &mut CteMap,
        inner: &mut MirRelationExpr,
        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
        context: &Context,
    ) -> Result<MirScalarExpr, PlanError> {
        // NOTE(review): `maybe_grow` presumably guards against stack overflow on deeply
        // nested expressions, since this function recurses on its children — confirm.
        maybe_grow(|| {
            use MirScalarExpr as SS;

            use HirScalarExpr::*;

            // Expressions that were already lowered up front are available as a column.
            if let Some(subquery_map) = subquery_map {
                if let Some(col) = subquery_map.get(&self) {
                    return Ok(SS::Column(*col));
                }
            }

            Ok::<MirScalarExpr, PlanError>(match self {
                Column(col_ref) => SS::Column(col_map.get(&col_ref)),
                Literal(row, typ) => SS::Literal(Ok(row), typ),
                Parameter(_) => panic!("cannot decorrelate expression with unbound parameters"),
                CallUnmaterializable(func) => SS::CallUnmaterializable(func),
                CallUnary { func, expr } => SS::CallUnary {
                    func,
                    expr: Box::new(expr.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        inner,
                        subquery_map,
                        context,
                    )?),
                },
                CallBinary { func, expr1, expr2 } => SS::CallBinary {
                    func,
                    expr1: Box::new(expr1.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        inner,
                        subquery_map,
                        context,
                    )?),
                    expr2: Box::new(expr2.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        inner,
                        subquery_map,
                        context,
                    )?),
                },
                CallVariadic { func, exprs } => SS::CallVariadic {
                    func,
                    exprs: exprs
                        .into_iter()
                        .map(|expr| {
                            expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
                        })
                        .collect::<Result<Vec<_>, _>>()?,
                },
                If { cond, then, els } => {
                    // The `If` case is complicated by the fact that we do not want to
                    // apply the `then` or `else` logic to tuples that respectively do
                    // not or do pass the `cond` test. Our strategy is to independently
                    // decorrelate the `then` and `else` logic, and apply each to tuples
                    // that respectively pass and do not pass the `cond` logic (which is
                    // executed, and so decorrelated, for all tuples).
                    //
                    // Informally, we turn the `if` statement into:
                    //
                    //   let then_case = inner.filter(cond).map(then);
                    //   let else_case = inner.filter(!cond).map(else);
                    //   return then_case.concat(else_case);
                    //
                    // We only require this if either expression would result in any
                    // computation beyond the expr itself, which we will interpret as
                    // "introduces additional columns". In the absence of correlation,
                    // we should just retain a `ScalarExpr::If` expression; the inverse
                    // transformation as above is complicated to recover after the fact,
                    // and we would benefit from not introducing the complexity.

                    // Arity before anything (including `cond`) has been lowered; used below
                    // to project away every helper column in the careful path.
                    let inner_arity = inner.arity();
                    let cond_expr =
                        cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;

                    // Defensive copies, in case we mangle these in decorrelation.
                    let inner_clone = inner.clone();
                    let then_clone = then.clone();
                    let else_clone = els.clone();

                    // Arity after lowering `cond`; any growth past this point is caused by
                    // the `then`/`els` branches.
                    let cond_arity = inner.arity();
                    let then_expr =
                        then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
                    let else_expr =
                        els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;

                    if cond_arity == inner.arity() {
                        // If no additional columns were added, we simply return the
                        // `If` variant with the updated expressions.
                        SS::If {
                            cond: Box::new(cond_expr),
                            then: Box::new(then_expr),
                            els: Box::new(else_expr),
                        }
                    } else {
                        // If columns were added, we need a more careful approach, as
                        // described above. First, we need to de-correlate each of
                        // the two expressions independently, and apply their cases
                        // as `MirRelationExpr::Map` operations.
                        //
                        // The trial lowering of `then`/`els` above is discarded: we restart
                        // from `inner_clone` (the relation right after lowering `cond`).

                        *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
                            // Restrict to records satisfying `cond_expr` and apply `then` as a map.
                            let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
                            let then_expr = then_clone.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                &mut then_inner,
                                subquery_map,
                                context,
                            )?;
                            let then_arity = then_inner.arity();
                            // Keep the original columns plus the freshly mapped `then` value.
                            then_inner = then_inner
                                .map_one(then_expr)
                                .project((0..inner_arity).chain(Some(then_arity)).collect());

                            // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
                            // "Not satisfying" means `cond` is either false or NULL.
                            let mut else_inner = get_inner.filter(vec![SS::CallVariadic {
                                func: mz_expr::VariadicFunc::Or,
                                exprs: vec![
                                    cond_expr
                                        .clone()
                                        .call_binary(SS::literal_false(), mz_expr::BinaryFunc::Eq),
                                    cond_expr.clone().call_is_null(),
                                ],
                            }]);
                            let else_expr = else_clone.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                &mut else_inner,
                                subquery_map,
                                context,
                            )?;
                            let else_arity = else_inner.arity();
                            // Keep the original columns plus the freshly mapped `els` value.
                            else_inner = else_inner
                                .map_one(else_expr)
                                .project((0..inner_arity).chain(Some(else_arity)).collect());

                            // concatenate the two results.
                            Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
                        })?;

                        // Both branches were projected to `inner_arity` columns followed by
                        // the branch value, so the result lives at index `inner_arity`.
                        SS::Column(inner_arity)
                    }
                }

                // Subqueries!
                // These are surprisingly subtle. Things to be careful of:

                // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
                // * change the row counts of the outer query
                // * accidentally compute its own value using the row counts of the outer query
                // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
                // query and then join the answers back on to the original rows of the outer query.

                // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
                // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
                Exists(expr) => {
                    let apply_requires_distinct_outer = true;
                    *inner = apply_existential_subquery(
                        id_gen,
                        inner.take_dangerous(),
                        col_map,
                        cte_map,
                        *expr,
                        apply_requires_distinct_outer,
                        context,
                    )?;
                    // The subquery result is read from the last column of the new `inner`.
                    SS::Column(inner.arity() - 1)
                }

                Select(expr) => {
                    let apply_requires_distinct_outer = true;
                    *inner = apply_scalar_subquery(
                        id_gen,
                        inner.take_dangerous(),
                        col_map,
                        cte_map,
                        *expr,
                        apply_requires_distinct_outer,
                        context,
                    )?;
                    // The subquery result is read from the last column of the new `inner`.
                    SS::Column(inner.arity() - 1)
                }
                Windowing(expr) => {
                    let partition_by = expr.partition_by;
                    let order_by = expr.order_by;

                    // argument lowering for scalar window functions
                    // (We need to specify the & _ in the arguments because of this problem:
                    // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
                    //
                    // Scalar window functions take no arguments of their own, so the
                    // aggregate input is just the original row (wrapped in a list) followed
                    // by the ORDER BY expressions.
                    let scalar_lower_args =
                        |_id_gen: &mut _,
                         _col_map: &_,
                         _cte_map: &mut _,
                         _get_inner: &mut _,
                         _subquery_map: &Option<&_>,
                         order_by_mir: Vec<MirScalarExpr>,
                         original_row_record,
                         original_row_record_type: ScalarType| {
                            let agg_input = MirScalarExpr::CallVariadic {
                                func: mz_expr::VariadicFunc::ListCreate {
                                    elem_type: original_row_record_type.clone(),
                                },
                                exprs: vec![original_row_record],
                            };
                            let mut agg_input = vec![agg_input];
                            agg_input.extend(order_by_mir.clone());
                            let agg_input = MirScalarExpr::CallVariadic {
                                func: mz_expr::VariadicFunc::RecordCreate {
                                    field_names: (0..agg_input.len())
                                        .map(|_| ColumnName::from("?column?"))
                                        .collect_vec(),
                                },
                                exprs: agg_input,
                            };
                            let list_type = ScalarType::List {
                                element_type: Box::new(original_row_record_type),
                                custom_id: None,
                            };
                            // NOTE(review): the declared type only lists the first record
                            // field (the list), not the ORDER BY fields — confirm this is
                            // all that `output_type` needs.
                            let agg_input_type = ScalarType::Record {
                                fields: std::iter::once(&list_type)
                                    .map(|t| {
                                        (ColumnName::from("?column?"), t.clone().nullable(false))
                                    })
                                    .collect(),
                                custom_id: None,
                            }
                            .nullable(false);

                            Ok((agg_input, agg_input_type))
                        };

                    // argument lowering for value window functions and aggregate window functions
                    let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
                        |id_gen: &mut _,
                         col_map: &_,
                         cte_map: &mut _,
                         get_inner: &mut _,
                         subquery_map: &Option<&_>,
                         order_by_mir: Vec<MirScalarExpr>,
                         original_row_record,
                         original_row_record_type| {
                            // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]

                            // Compute the encoded args for all rows
                            let mir_encoded_args = hir_encoded_args.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                get_inner,
                                subquery_map,
                                context,
                            )?;
                            let mir_encoded_args_type = mir_encoded_args
                                .typ(&get_inner.typ().column_types)
                                .scalar_type;

                            // Build a new record that has two fields:
                            // 1. the original row in a record
                            // 2. the encoded args (which can be either a single value, or a record
                            //    if the window function has multiple arguments, such as `lag`)
                            let fn_input_record_fields: Box<[_]> =
                                [original_row_record_type, mir_encoded_args_type]
                                    .iter()
                                    .map(|t| {
                                        (ColumnName::from("?column?"), t.clone().nullable(false))
                                    })
                                    .collect();
                            let fn_input_record = MirScalarExpr::CallVariadic {
                                func: mz_expr::VariadicFunc::RecordCreate {
                                    field_names: fn_input_record_fields
                                        .iter()
                                        .map(|(n, _)| n.clone())
                                        .collect_vec(),
                                },
                                exprs: vec![original_row_record, mir_encoded_args],
                            };
                            let fn_input_record_type = ScalarType::Record {
                                fields: fn_input_record_fields,
                                custom_id: None,
                            }
                            .nullable(false);

                            // Build a new record with the record above + the ORDER BY exprs
                            // This follows the standard encoding of ORDER BY exprs used by aggregate functions
                            let mut agg_input = vec![fn_input_record];
                            agg_input.extend(order_by_mir.clone());
                            let agg_input = MirScalarExpr::CallVariadic {
                                func: mz_expr::VariadicFunc::RecordCreate {
                                    field_names: (0..agg_input.len())
                                        .map(|_| ColumnName::from("?column?"))
                                        .collect_vec(),
                                },
                                exprs: agg_input,
                            };

                            // NOTE(review): as in the scalar case, only the first record
                            // field is described by the declared type — confirm.
                            let agg_input_type = ScalarType::Record {
                                fields: [(
                                    ColumnName::from("?column?"),
                                    fn_input_record_type.nullable(false),
                                )]
                                .into(),
                                custom_id: None,
                            }
                            .nullable(false);

                            Ok((agg_input, agg_input_type))
                        }
                    };

                    // All three kinds share `window_func_applied_to`; they differ only in
                    // the MIR aggregate function and in how the arguments are encoded.
                    match expr.func {
                        WindowExprType::Scalar(scalar_window_expr) => {
                            let mir_aggr_func = scalar_window_expr.into_expr();
                            Self::window_func_applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                inner,
                                subquery_map,
                                partition_by,
                                order_by,
                                mir_aggr_func,
                                scalar_lower_args,
                                context,
                            )?
                        }
                        WindowExprType::Value(value_window_expr) => {
                            let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();

                            Self::window_func_applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                inner,
                                subquery_map,
                                partition_by,
                                order_by,
                                mir_aggr_func,
                                value_or_aggr_lower_args(hir_encoded_args),
                                context,
                            )?
                        }
                        WindowExprType::Aggregate(aggr_window_expr) => {
                            let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();

                            Self::window_func_applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                inner,
                                subquery_map,
                                partition_by,
                                order_by,
                                mir_aggr_func,
                                value_or_aggr_lower_args(hir_encoded_args),
                                context,
                            )?
                        }
                    }
                }
            })
        })
    }
1317
    /// Lowers a window function call and appends its result to `inner`.
    ///
    /// The ORDER BY and PARTITION BY expressions are lowered against `inner`, the aggregate
    /// input is built by `lower_args`, and `mir_aggr_func` is evaluated in a `Reduce` whose
    /// key is the outer columns of `col_map` plus the partition key. The list produced by
    /// the aggregate is unnested and unpacked back into the original columns followed by
    /// the window function result.
    ///
    /// `*inner` is replaced by the resulting relation, and the returned expression is a
    /// reference to its last column.
    ///
    /// `lower_args` receives the lowered ORDER BY expressions, an expression building a
    /// record of the original row, and that record's type; it returns the aggregate input
    /// expression together with its type.
    fn window_func_applied_to<F>(
        id_gen: &mut mz_ore::id_gen::IdGen,
        col_map: &ColumnMap,
        cte_map: &mut CteMap,
        inner: &mut MirRelationExpr,
        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
        partition_by: Vec<HirScalarExpr>,
        order_by: Vec<HirScalarExpr>,
        mir_aggr_func: AggregateFunc,
        lower_args: F,
        context: &Context,
    ) -> Result<MirScalarExpr, PlanError>
    where
        F: FnOnce(
            &mut mz_ore::id_gen::IdGen,
            &ColumnMap,
            &mut CteMap,
            &mut MirRelationExpr,
            &Option<&BTreeMap<HirScalarExpr, usize>>,
            Vec<MirScalarExpr>,
            MirScalarExpr,
            ScalarType,
        ) -> Result<(MirScalarExpr, ColumnType), PlanError>,
    {
        // Example MIRs for a window function (specifically, a window aggregation):
        //
        // CREATE TABLE t7(x INT, y INT);
        //
        // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
        //
        // Decorrelated Plan
        // Project (#3)
        //   Map (#2)
        //     Project (#3..=#5)
        //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
        //         FlatMap unnest_list(#1)
        //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
        //             Map ((#0 + #1))
        //               CrossJoin
        //                 Constant
        //                   - ()
        //                 Get materialize.public.t7
        //
        // The same query after optimizations:
        //
        // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
        //
        // Optimized Plan
        // Explained Query:
        //   Project (#2)
        //     Map (record_get[0](#1))
        //       FlatMap unnest_list(#0)
        //         Project (#1)
        //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
        //             ReadStorage materialize.public.t7
        //
        // The `row(row(row(...), ...), ...)` stuff means the following:
        // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
        //   - The <arguments to window function> can be either a single value or itself a
        //     `row` if there are multiple arguments.
        //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
        //     ORDER BY columns.
        //   - The <original row> currently always captures the entire original row. This should
        //     improve when we make `ProjectionPushdown` smarter, see
        //     https://github.com/MaterializeInc/database-issues/issues/5090
        //
        // TODO:
        // We should probably introduce some dedicated Datum constructor functions instead of `row`
        // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
        // might even introduce dedicated Datum enum variants, so that the rendering code also
        // becomes more readable (and possibly slightly more performant).

        *inner = inner
            .take_dangerous()
            .let_in(id_gen, |id_gen, mut get_inner| {
                // Lower the ORDER BY expressions first; this may extend `get_inner`.
                let order_by_mir = order_by
                    .into_iter()
                    .map(|o| {
                        o.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut get_inner,
                            subquery_map,
                            context,
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                // Record input arity here so that any group_keys that need to mutate get_inner
                // don't add those columns to the aggregate input.
                let input_type = get_inner.typ();
                let input_arity = input_type.arity();
                // The reduction that computes the window function must be keyed on the columns
                // from the outer context, plus the expressions in the partition key. The current
                // subquery will be 'executed' for every distinct row from the outer context so
                // by putting the outer columns in the grouping key we isolate each re-execution.
                let mut group_key = col_map
                    .inner
                    .iter()
                    .map(|(_, outer_col)| *outer_col)
                    .sorted()
                    .collect_vec();
                for p in partition_by {
                    let key = p.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        context,
                    )?;
                    // A plain column reference can be used as a key directly; any other
                    // expression is first materialized as a new trailing column.
                    if let MirScalarExpr::Column(c) = key {
                        group_key.push(c);
                    } else {
                        get_inner = get_inner.map_one(key);
                        group_key.push(get_inner.arity() - 1);
                    }
                }

                get_inner.let_in(id_gen, |id_gen, mut get_inner| {
                    // Original columns of the relation
                    let fields: Box<_> = input_type
                        .column_types
                        .iter()
                        .take(input_arity)
                        .map(|t| (ColumnName::from("?column?"), t.clone()))
                        .collect();

                    // Original row made into a record
                    let original_row_record = MirScalarExpr::CallVariadic {
                        func: mz_expr::VariadicFunc::RecordCreate {
                            field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
                        },
                        exprs: (0..input_arity).map(MirScalarExpr::Column).collect_vec(),
                    };
                    let original_row_record_type = ScalarType::Record {
                        fields,
                        custom_id: None,
                    };

                    // Let the caller encode the function arguments into the aggregate input.
                    let (agg_input, agg_input_type) = lower_args(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        order_by_mir,
                        original_row_record,
                        original_row_record_type,
                    )?;

                    let aggregate = mz_expr::AggregateExpr {
                        func: mir_aggr_func,
                        expr: agg_input,
                        distinct: false,
                    };

                    // Actually call reduce with the window function
                    // The output of the aggregation function should be a list of tuples that has
                    // the result in the first position, and the original row in the second position
                    //
                    // After the reduce, the aggregate's list sits right after the key columns,
                    // i.e. at index `group_key.len()`, which is what gets unnested.
                    let mut reduce = get_inner
                        .reduce(group_key.clone(), vec![aggregate.clone()], None)
                        .flat_map(
                            mz_expr::TableFunc::UnnestList {
                                el_typ: aggregate
                                    .func
                                    .output_type(agg_input_type)
                                    .scalar_type
                                    .unwrap_list_element_type()
                                    .clone(),
                            },
                            vec![MirScalarExpr::Column(group_key.len())],
                        );
                    // The unnested record is the last column at this point.
                    let record_col = reduce.arity() - 1;

                    // Unpack the record output by the window function
                    // (field 1 of the record is the original row; extract each of its columns).
                    for c in 0..input_arity {
                        reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
                            expr: Box::new(MirScalarExpr::CallUnary {
                                func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
                                expr: Box::new(MirScalarExpr::Column(record_col)),
                            }),
                        });
                    }

                    // Append the column with the result of the window function.
                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
                        expr: Box::new(MirScalarExpr::Column(record_col)),
                    });

                    // Keep only the unpacked original columns and the result column, dropping
                    // the key columns, the list, and the unnested record.
                    let agg_col = record_col + 1 + input_arity;
                    Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
                })
            })?;
        // The window function result is the last column of the new `inner`.
        Ok(MirScalarExpr::Column(inner.arity() - 1))
    }
1517
1518 /// Applies the subqueries in the given list of scalar expressions to every distinct
1519 /// value of the given relation and returns a join of the given relation with all
1520 /// the subqueries found, and the mapping of scalar expressions with columns projected
1521 /// by the returned join that will hold their results.
1522 fn lower_subqueries(
1523 exprs: &[Self],
1524 id_gen: &mut mz_ore::id_gen::IdGen,
1525 col_map: &ColumnMap,
1526 cte_map: &mut CteMap,
1527 inner: MirRelationExpr,
1528 context: &Context,
1529 ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1530 let mut subquery_map = BTreeMap::new();
1531 let output = inner.let_in(id_gen, |id_gen, get_inner| {
1532 let mut subqueries = Vec::new();
1533 let distinct_inner = get_inner.clone().distinct();
1534 for expr in exprs.iter() {
1535 expr.visit_pre_post(
1536 &mut |e| match e {
1537 // For simplicity, subqueries within a conditional statement will be
1538 // lowered when lowering the conditional expression.
1539 HirScalarExpr::If { .. } => Some(vec![]),
1540 _ => None,
1541 },
1542 &mut |e| match e {
1543 HirScalarExpr::Select(expr) => {
1544 let apply_requires_distinct_outer = false;
1545 let subquery = apply_scalar_subquery(
1546 id_gen,
1547 distinct_inner.clone(),
1548 col_map,
1549 cte_map,
1550 (**expr).clone(),
1551 apply_requires_distinct_outer,
1552 context,
1553 )
1554 .unwrap();
1555
1556 subqueries.push((e.clone(), subquery));
1557 }
1558 HirScalarExpr::Exists(expr) => {
1559 let apply_requires_distinct_outer = false;
1560 let subquery = apply_existential_subquery(
1561 id_gen,
1562 distinct_inner.clone(),
1563 col_map,
1564 cte_map,
1565 (**expr).clone(),
1566 apply_requires_distinct_outer,
1567 context,
1568 )
1569 .unwrap();
1570 subqueries.push((e.clone(), subquery));
1571 }
1572 _ => {}
1573 },
1574 )?;
1575 }
1576
1577 if subqueries.is_empty() {
1578 Ok::<MirRelationExpr, PlanError>(get_inner)
1579 } else {
1580 let inner_arity = get_inner.arity();
1581 let mut total_arity = inner_arity;
1582 let mut join_inputs = vec![get_inner];
1583 let mut join_input_arities = vec![inner_arity];
1584 for (expr, subquery) in subqueries.into_iter() {
1585 // Avoid lowering duplicated subqueries
1586 if !subquery_map.contains_key(&expr) {
1587 let subquery_arity = subquery.arity();
1588 assert_eq!(subquery_arity, inner_arity + 1);
1589 join_inputs.push(subquery);
1590 join_input_arities.push(subquery_arity);
1591 total_arity += subquery_arity;
1592
1593 // Column with the value of the subquery
1594 subquery_map.insert(expr, total_arity - 1);
1595 }
1596 }
1597 // Each subquery projects all the columns of the outer context (distinct_inner)
1598 // plus 1 column, containing the result of the subquery. Those columns must be
1599 // joined with the outer/main relation (get_inner).
1600 let input_mapper =
1601 mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1602 let equivalences = (0..inner_arity)
1603 .map(|col| {
1604 join_inputs
1605 .iter()
1606 .enumerate()
1607 .map(|(input, _)| {
1608 MirScalarExpr::Column(input_mapper.map_column_to_global(col, input))
1609 })
1610 .collect_vec()
1611 })
1612 .collect_vec();
1613 Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1614 }
1615 })?;
1616 Ok((output, subquery_map))
1617 }
1618
1619 /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1620 pub fn lower_uncorrelated(self) -> Result<MirScalarExpr, PlanError> {
1621 use MirScalarExpr as SS;
1622
1623 use HirScalarExpr::*;
1624
1625 Ok(match self {
1626 Column(ColumnRef { level: 0, column }) => SS::Column(column),
1627 Literal(datum, typ) => SS::Literal(Ok(datum), typ),
1628 CallUnmaterializable(func) => SS::CallUnmaterializable(func),
1629 CallUnary { func, expr } => SS::CallUnary {
1630 func,
1631 expr: Box::new(expr.lower_uncorrelated()?),
1632 },
1633 CallBinary { func, expr1, expr2 } => SS::CallBinary {
1634 func,
1635 expr1: Box::new(expr1.lower_uncorrelated()?),
1636 expr2: Box::new(expr2.lower_uncorrelated()?),
1637 },
1638 CallVariadic { func, exprs } => SS::CallVariadic {
1639 func,
1640 exprs: exprs
1641 .into_iter()
1642 .map(|expr| expr.lower_uncorrelated())
1643 .collect::<Result<_, _>>()?,
1644 },
1645 If { cond, then, els } => SS::If {
1646 cond: Box::new(cond.lower_uncorrelated()?),
1647 then: Box::new(then.lower_uncorrelated()?),
1648 els: Box::new(els.lower_uncorrelated()?),
1649 },
1650 Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1651 sql_bail!("unexpected ScalarExpr in uncorrelated plan: {:?}", self);
1652 }
1653 })
1654 }
1655}
1656
/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
/// will, in effect, be executed once for every distinct row in `outer`, and the
/// results will be joined with `outer`. Note that columns in `outer` that are
/// not depended upon by `inner` are thrown away before the distinct, so that we
/// don't perform needless computation of `inner`.
///
/// `branch` will inspect the contents of `inner` to determine whether `inner`
/// is not multiplicity sensitive (roughly, contains only maps, filters,
/// projections, and calls to table functions). If it is not multiplicity
/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
/// operations that were not present in `inner`, then set
/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
/// that distinctifies `outer`.
///
/// The caller must supply the `apply` function that applies the rewritten
/// `inner` to `outer`.
///
/// On the non-simple path the result consists of all columns of `outer`
/// followed by the columns that `apply` produced beyond the key columns.
///
/// # Panics
///
/// Panics if `inner` contains a `Let` whose id is already present in `cte_map`.
fn branch<F>(
    id_gen: &mut mz_ore::id_gen::IdGen,
    outer: MirRelationExpr,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: HirRelationExpr,
    apply_requires_distinct_outer: bool,
    context: &Context,
    apply: F,
) -> Result<MirRelationExpr, PlanError>
where
    F: FnOnce(
        &mut mz_ore::id_gen::IdGen,
        HirRelationExpr,
        MirRelationExpr,
        &ColumnMap,
        &mut CteMap,
        &Context,
    ) -> Result<MirRelationExpr, PlanError>,
{
    // TODO: It would be nice to have a version of this code w/o optimizations,
    // at the least for purposes of understanding. It was difficult for one reader
    // to understand the required properties of `outer` and `col_map`.

    // If the inner expression is sufficiently simple, it is safe to apply it
    // *directly* to outer, rather than applying it to the distinctified key
    // (see below).
    //
    // As an example, consider the following two queries:
    //
    //     CREATE TABLE t (a int, b int);
    //     SELECT a, series FROM t, generate_series(1, t.b) series;
    //
    // The "simple" path for the `SELECT` yields
    //
    //     %0 =
    //     | Get t
    //     | FlatMap generate_series(1, #1)
    //
    // while the non-simple path yields:
    //
    //    %0 =
    //    | Get t
    //
    //    %1 =
    //    | Get t
    //    | Distinct group=(#1)
    //    | FlatMap generate_series(1, #0)
    //
    //    %2 =
    //    | LeftJoin %1 %2 (= #1 #2)
    //
    // There is a tradeoff here: the simple plan is stateless, but the non-
    // simple plan may do (much) less computation if there are only a few
    // distinct values of `t.b`.
    //
    // We apply a very simple heuristic here and take the simple path if `inner`
    // contains only maps, filters, projections, and calls to table functions.
    // The intuition is that straightforward usage of table functions should
    // take the simple path, while everything else should not. (In theory we
    // think this transformation is valid as long as `inner` does not contain a
    // Reduce, Distinct, or TopK node, but it is not always an optimization in
    // the general case.)
    //
    // TODO(benesch): this should all be handled by a proper optimizer, but
    // detecting the moment of decorrelation in the optimizer right now is too
    // hard.
    let mut is_simple = true;
    #[allow(deprecated)]
    inner.visit(0, &mut |expr, _| match expr {
        HirRelationExpr::Constant { .. }
        | HirRelationExpr::Project { .. }
        | HirRelationExpr::Map { .. }
        | HirRelationExpr::Filter { .. }
        | HirRelationExpr::CallTable { .. } => (),
        _ => is_simple = false,
    });
    if is_simple && !apply_requires_distinct_outer {
        // Simple path: apply `inner` directly to all of `outer`.
        // NOTE(review): `enter_scope` presumably pushes existing references one level
        // out and accounts for the columns of `outer` not covered by `col_map` — confirm.
        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
        return outer.let_in(id_gen, |id_gen, get_outer| {
            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
        });
    }

    // The key consists of the columns from the outer expression upon which the
    // inner relation depends. We discover these dependencies by walking the
    // inner relation expression and looking for column references whose level
    // escapes inner.
    //
    // At the end of this process, `key` contains the decorrelated position of
    // each outer column, according to the passed-in `col_map`, and
    // `new_col_map` maps each outer column to its new ordinal position in key.
    let mut outer_cols = BTreeSet::new();
    #[allow(deprecated)]
    inner.visit_columns(0, &mut |depth, col| {
        // Test if the column reference escapes the subquery.
        if col.level > depth {
            // Normalize the reference so that it is expressed relative to the root
            // of `inner` rather than to the nested scope it was found in.
            outer_cols.insert(ColumnRef {
                level: col.level - depth,
                column: col.column,
            });
        }
    });
    // Collect all the outer columns referenced by any CTE referenced by
    // the inner relation.
    #[allow(deprecated)]
    inner.visit(0, &mut |e, _| match e {
        HirRelationExpr::Get {
            id: mz_expr::Id::Local(id),
            ..
        } => {
            if let Some(cte_desc) = cte_map.get(id) {
                // Only the columns of `col_map` that fall within the CTE's own outer
                // relation are added as dependencies.
                let cte_outer_arity = cte_desc.outer_relation.arity();
                outer_cols.extend(
                    col_map
                        .inner
                        .iter()
                        .filter(|(_, position)| **position < cte_outer_arity)
                        .map(|(c, _)| {
                            // `col_map` maps column references to column positions in
                            // `outer`'s projection.
                            // `outer_cols` is meant to contain the external column
                            // references in `inner`.
                            // Since `inner` defines a new scope, any column reference
                            // in `col_map` is one level deeper when seen from within
                            // `inner`, hence the +1.
                            ColumnRef {
                                level: c.level + 1,
                                column: c.column,
                            }
                        }),
                );
            }
        }
        HirRelationExpr::Let { id, .. } => {
            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
            // we would need to remove the old CTE with the same ID temporarily while
            // traversing the definition of the new CTE under the same ID.
            assert!(!cte_map.contains_key(id));
        }
        _ => {}
    });
    let mut new_col_map = BTreeMap::new();
    let mut key = vec![];
    // `outer_cols` is a `BTreeSet`, so the key columns are assigned in a deterministic
    // (sorted) order.
    for col in outer_cols {
        new_col_map.insert(col, key.len());
        key.push(col_map.get(&ColumnRef {
            // Note: `outer_cols` contains the external column references within `inner`.
            // We must compensate for `inner`'s scope when translating column references
            // as seen within `inner` to column references as seen from `outer`'s context,
            // hence the -1.
            level: col.level - 1,
            column: col.column,
        }));
    }
    let new_col_map = ColumnMap::new(new_col_map);
    outer.let_in(id_gen, |id_gen, get_outer| {
        let keyed_outer = if key.is_empty() {
            // Don't depend on outer at all if the branch is not correlated,
            // which yields vastly better query plans. Note that this is a bit
            // weird in that the branch will be computed even if outer has no
            // rows, whereas if it had been correlated it would not (and *could*
            // not) have been computed if outer had no rows, but the callers of
            // this function don't mind these somewhat-weird semantics.
            MirRelationExpr::constant(vec![vec![]], RelationType::new(vec![]))
        } else {
            get_outer.clone().distinct_by(key.clone())
        };
        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
            let oa = get_outer.arity();
            let branch = apply(
                id_gen,
                inner,
                get_keyed_outer,
                &new_col_map,
                cte_map,
                context,
            )?;
            let ba = branch.arity();
            // Join `outer` (input 0) with the branch (input 1): the i-th key column of
            // the branch must equal column `key[i]` of `outer`.
            let joined = MirRelationExpr::join(
                vec![get_outer.clone(), branch],
                key.iter()
                    .enumerate()
                    .map(|(i, &k)| vec![(0, k), (1, i)])
                    .collect(),
            )
            // throw away the right-hand copy of the key we just joined on
            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
            Ok(joined)
        })
    })
}
1867
1868fn apply_scalar_subquery(
1869 id_gen: &mut mz_ore::id_gen::IdGen,
1870 outer: MirRelationExpr,
1871 col_map: &ColumnMap,
1872 cte_map: &mut CteMap,
1873 scalar_subquery: HirRelationExpr,
1874 apply_requires_distinct_outer: bool,
1875 context: &Context,
1876) -> Result<MirRelationExpr, PlanError> {
1877 branch(
1878 id_gen,
1879 outer,
1880 col_map,
1881 cte_map,
1882 scalar_subquery,
1883 apply_requires_distinct_outer,
1884 context,
1885 |id_gen, expr, get_inner, col_map, cte_map, context| {
1886 // compute for every row in get_inner
1887 let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1888 let col_type = select.typ().column_types.into_last();
1889
1890 let inner_arity = get_inner.arity();
1891 // We must determine a count for each `get_inner` prefix,
1892 // and report an error if that count exceeds one.
1893 let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1894 // Count for each `get_inner` prefix.
1895 let counts = get_select.clone().reduce(
1896 (0..inner_arity).collect::<Vec<_>>(),
1897 vec![mz_expr::AggregateExpr {
1898 func: mz_expr::AggregateFunc::Count,
1899 expr: MirScalarExpr::literal_true(),
1900 distinct: false,
1901 }],
1902 None,
1903 );
1904 // Errors should result from counts > 1.
1905 let errors = counts
1906 .filter(vec![MirScalarExpr::Column(inner_arity).call_binary(
1907 MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
1908 mz_expr::BinaryFunc::Gt,
1909 )])
1910 .project((0..inner_arity).collect::<Vec<_>>())
1911 .map_one(MirScalarExpr::literal(
1912 Err(mz_expr::EvalError::MultipleRowsFromSubquery),
1913 col_type.clone().scalar_type,
1914 ));
1915 // Return `get_select` and any errors added in.
1916 Ok::<_, PlanError>(get_select.union(errors))
1917 })?;
1918 // append Null to anything that didn't return any rows
1919 let default = vec![(Datum::Null, col_type.scalar_type)];
1920 get_inner.lookup(id_gen, guarded, default)
1921 },
1922 )
1923}
1924
1925fn apply_existential_subquery(
1926 id_gen: &mut mz_ore::id_gen::IdGen,
1927 outer: MirRelationExpr,
1928 col_map: &ColumnMap,
1929 cte_map: &mut CteMap,
1930 subquery_expr: HirRelationExpr,
1931 apply_requires_distinct_outer: bool,
1932 context: &Context,
1933) -> Result<MirRelationExpr, PlanError> {
1934 branch(
1935 id_gen,
1936 outer,
1937 col_map,
1938 cte_map,
1939 subquery_expr,
1940 apply_requires_distinct_outer,
1941 context,
1942 |id_gen, expr, get_inner, col_map, cte_map, context| {
1943 let exists = expr
1944 // compute for every row in get_inner
1945 .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
1946 // throw away actual values and just remember whether or not there were __any__ rows
1947 .distinct_by((0..get_inner.arity()).collect())
1948 // Append true to anything that returned any rows.
1949 .map(vec![MirScalarExpr::literal_true()]);
1950
1951 // append False to anything that didn't return any rows
1952 get_inner.lookup(id_gen, exists, vec![(Datum::False, ScalarType::Bool)])
1953 },
1954 )
1955}
1956
1957impl AggregateExpr {
1958 fn applied_to(
1959 self,
1960 id_gen: &mut mz_ore::id_gen::IdGen,
1961 col_map: &ColumnMap,
1962 cte_map: &mut CteMap,
1963 inner: &mut MirRelationExpr,
1964 context: &Context,
1965 ) -> Result<mz_expr::AggregateExpr, PlanError> {
1966 let AggregateExpr {
1967 func,
1968 expr,
1969 distinct,
1970 } = self;
1971
1972 Ok(mz_expr::AggregateExpr {
1973 func: func.into_expr(),
1974 expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
1975 distinct,
1976 })
1977 }
1978}
1979
1980/// Attempts an efficient outer join, if `on` has equijoin structure.
1981///
1982/// Both `left` and `right` are decorrelated inputs.
1983///
1984/// The first `oa` columns correspond to an outer context: we should do the
1985/// outer join independently for each prefix. In the case that `on` contains
1986/// just some equality tests between columns of `left` and `right` and some
1987/// local predicates, we can employ a relatively simple plan.
1988///
1989/// The last `on_subquery_types.len()` columns correspond to results from
1990/// subqueries defined in the `on` clause - we treat those as theta-join
1991/// conditions that prohibit the use of the simple plan attempted here.
1992fn attempt_outer_equijoin(
1993 left: MirRelationExpr,
1994 right: MirRelationExpr,
1995 on: MirScalarExpr,
1996 on_subquery_types: Vec<ColumnType>,
1997 kind: JoinKind,
1998 oa: usize,
1999 id_gen: &mut mz_ore::id_gen::IdGen,
2000 context: &Context,
2001) -> Result<Option<MirRelationExpr>, PlanError> {
2002 // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
2003 // predicates that reference subqueries as long as these subqueries don't
2004 // reference `left` and `right` at the same time.
2005 //
2006 // TODO(database-issues#6828): This code can be improved as follows:
2007 //
2008 // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
2009 // 2. Use the canonicalized `on` predicate in the non-equijoin based
2010 // lowering strategy.
2011 // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
2012 // 4. Pass the classified `OnPredicates` as a parameter.
2013 // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
2014 //
2015 // Steps (1 + 2) require further investigation because we might change the
2016 // error semantics in case the `on` predicate contains a literal error..
2017
2018 let l_type = left.typ();
2019 let r_type = right.typ();
2020 let la = l_type.column_types.len() - oa;
2021 let ra = r_type.column_types.len() - oa;
2022 let sa = on_subquery_types.len();
2023
2024 // The output type contains [outer, left, right, sa] attributes.
2025 let mut output_type = Vec::with_capacity(oa + la + ra + sa);
2026 output_type.extend(l_type.column_types);
2027 output_type.extend(r_type.column_types.into_iter().skip(oa));
2028 output_type.extend(on_subquery_types);
2029
2030 // Generally healthy to do, but specifically `USING` conditions sometimes
2031 // put an `AND true` at the end of the `ON` condition.
2032 //
2033 // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
2034 // However, in that case it's not clear that we won't see regressions if
2035 // `on` simplifies to a literal error.
2036 let mut on = vec![on];
2037 mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);
2038
2039 // Form the left and right types without the outer attributes.
2040 output_type.drain(0..oa);
2041 let lt = output_type.drain(0..la).collect_vec();
2042 let rt = output_type.drain(0..ra).collect_vec();
2043 assert!(output_type.len() == sa);
2044
2045 let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
2046 if !on_predicates.is_equijoin(context) {
2047 return Ok(None);
2048 }
2049
2050 // If we've gotten this far, we can do the clever thing.
2051 // We'll want to use left and right multiple times
2052 let result = left.let_in(id_gen, |id_gen, get_left| {
2053 right.let_in(id_gen, |id_gen, get_right| {
2054 // TODO: we know that we can re-use the arrangements of left and right
2055 // needed for the inner join with each of the conditional outer joins.
2056 // It is not clear whether we should hint that, or just let the planner
2057 // and optimizer run and see what happens.
2058
2059 // We'll want the inner join (minus repeated columns)
2060 let join = MirRelationExpr::join(
2061 vec![get_left.clone(), get_right.clone()],
2062 (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
2063 )
2064 // remove those columns from `right` repeating the first `oa` columns.
2065 .project(
2066 (0..(oa + la))
2067 .chain((oa + la + oa)..(oa + la + oa + ra))
2068 .collect(),
2069 )
2070 // apply the filter constraints here, to ensure nulls are not matched.
2071 .filter(on);
2072
2073 // We'll want to re-use the results of the join multiple times.
2074 join.let_in(id_gen, |id_gen, get_join| {
2075 let mut result = get_join.clone();
2076
2077 // A collection of keys present in both left and right collections.
2078 let join_keys = on_predicates.join_keys();
2079 let both_keys_arity = join_keys.len();
2080 let both_keys = get_join.restrict(join_keys).distinct();
2081
2082 // The plan is now to determine the left and right rows matched in the
2083 // inner join, subtract them from left and right respectively, pad what
2084 // remains with nulls, and fold them in to `result`.
2085
2086 both_keys.let_in(id_gen, |_id_gen, get_both| {
2087 if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
2088 // Rows in `left` matched in the inner equijoin. This is
2089 // a semi-join between `left` and `both_keys`.
2090 let left_present = MirRelationExpr::join_scalars(
2091 vec![
2092 get_left
2093 .clone()
2094 // Push local predicates.
2095 .filter(on_predicates.lhs()),
2096 get_both.clone(),
2097 ],
2098 itertools::zip_eq(
2099 on_predicates.eq_lhs(),
2100 (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
2101 )
2102 .map(|(l_key, b_key)| [l_key, b_key].to_vec())
2103 .collect(),
2104 )
2105 .project((0..(oa + la)).collect());
2106
2107 // Determine the types of nulls to use as filler.
2108 let right_fill = rt
2109 .into_iter()
2110 .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
2111 .collect();
2112
2113 // Add to `result` absent elements, filled with typed nulls.
2114 result = left_present
2115 .negate()
2116 .union(get_left.clone())
2117 .map(right_fill)
2118 .union(result);
2119 }
2120
2121 if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
2122 // Rows in `right` matched in the inner equijoin. This
2123 // is a semi-join between `right` and `both_keys`.
2124 let right_present = MirRelationExpr::join_scalars(
2125 vec![
2126 get_right
2127 .clone()
2128 // Push local predicates.
2129 .filter(on_predicates.rhs()),
2130 get_both,
2131 ],
2132 itertools::zip_eq(
2133 on_predicates.eq_rhs(),
2134 (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
2135 )
2136 .map(|(r_key, b_key)| [r_key, b_key].to_vec())
2137 .collect(),
2138 )
2139 .project((0..(oa + ra)).collect());
2140
2141 // Determine the types of nulls to use as filler.
2142 let left_fill = lt
2143 .into_iter()
2144 .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
2145 .collect();
2146
2147 // Add to `result` absent elements, prepended with typed nulls.
2148 result = right_present
2149 .negate()
2150 .union(get_right.clone())
2151 .map(left_fill)
2152 // Permute left fill before right values.
2153 .project(
2154 itertools::chain!(
2155 0..oa, // Preserve `outer`.
2156 oa + ra..oa + la + ra, // Increment the next `la` cols by `ra`.
2157 oa..oa + ra // Decrement the next `ra` cols by `la`.
2158 )
2159 .collect(),
2160 )
2161 .union(result)
2162 }
2163
2164 Ok::<_, PlanError>(result)
2165 })
2166 })
2167 })
2168 })?;
2169 Ok(Some(result))
2170}
2171
/// A struct that represents the predicates in the `on` clause in a form
/// suitable for efficient planning outer joins with equijoin predicates.
///
/// Each predicate is stored as an [`OnPredicate`] variant that records which
/// join inputs it references.
struct OnPredicates {
    /// A store for classified `ON` predicates.
    ///
    /// Predicates that reference a single side are adjusted to assume an
    /// `outer × <side>` schema.
    predicates: Vec<OnPredicate>,
    /// Number of outer context columns.
    ///
    /// The equijoin key iterators emit column references `0..oa` as a prefix
    /// before the keys taken from the equality predicates.
    oa: usize,
}
2183
2184impl OnPredicates {
2185 const I_OUT: usize = 0; // outer context input position
2186 const I_LHS: usize = 1; // lhs input position
2187 const I_RHS: usize = 2; // rhs input position
2188 const I_SUB: usize = 3; // on subqueries input position
2189
2190 /// Classify the predicates in the `on` clause of an outer join.
2191 ///
2192 /// The other parameters are arities of the input parts:
2193 ///
2194 /// - `oa` is the arity of the `outer` context.
2195 /// - `la` is the arity of the `left` input.
2196 /// - `ra` is the arity of the `right` input.
2197 /// - `sa` is the arity of the `on` subqueries.
2198 ///
2199 /// The constructor assumes that:
2200 ///
2201 /// 1. The `on` parameter will be applied on a result that has the following
2202 /// schema `outer × left × right × on_subqueries`.
2203 /// 2. The `on` parameter is already adjusted to assume that schema.
2204 /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2205 /// MirScalarExpr` with `canonicalize_predicates`.
2206 fn new(
2207 oa: usize,
2208 la: usize,
2209 ra: usize,
2210 sa: usize,
2211 on: Vec<MirScalarExpr>,
2212 _context: &Context,
2213 ) -> Self {
2214 use mz_expr::BinaryFunc::Eq;
2215
2216 // Re-bind those locally for more compact pattern matching.
2217 const I_LHS: usize = OnPredicates::I_LHS;
2218 const I_RHS: usize = OnPredicates::I_RHS;
2219
2220 // Self parameters.
2221 let mut predicates = Vec::with_capacity(on.len());
2222
2223 // Helpers for populating `predicates`.
2224 let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2225 let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2226 let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2227 inner_join_mapper
2228 .lookup_inputs(expr)
2229 .filter(|&i| i != Self::I_OUT)
2230 .collect()
2231 };
2232 let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2233 inner_join_mapper
2234 .lookup_inputs(expr)
2235 .any(|i| i == Self::I_SUB)
2236 };
2237
2238 // Iterate over `on` elements and populate `predicates`.
2239 for mut predicate in on {
2240 if predicate.might_error() {
2241 tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2242 // Treat predicates that can produce a literal error as Theta.
2243 predicates.push(OnPredicate::Theta(predicate));
2244 } else if has_subquery_refs(&predicate) {
2245 tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2246 // Treat predicates referencing an `on` subquery as Theta.
2247 predicates.push(OnPredicate::Theta(predicate));
2248 } else if let MirScalarExpr::CallBinary {
2249 func: Eq,
2250 expr1,
2251 expr2,
2252 } = &mut predicate
2253 {
2254 // Obtain the non-outer inputs referenced by each side.
2255 let inputs1 = lookup_inputs(expr1);
2256 let inputs2 = lookup_inputs(expr2);
2257
2258 match (&inputs1[..], &inputs2[..]) {
2259 // Neither side references an input. This could be a
2260 // constant expression or an expression that depends only on
2261 // the outer context.
2262 ([], []) => {
2263 predicates.push(OnPredicate::Const(predicate));
2264 }
2265 // Both sides reference different inputs.
2266 ([I_LHS], [I_RHS]) => {
2267 let lhs = expr1.take();
2268 let mut rhs = expr2.take();
2269 rhs.permute(&rhs_permutation);
2270 predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2271 predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2272 predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2273 }
2274 // Both sides reference different inputs (swapped).
2275 ([I_RHS], [I_LHS]) => {
2276 let lhs = expr2.take();
2277 let mut rhs = expr1.take();
2278 rhs.permute(&rhs_permutation);
2279 predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2280 predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2281 predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2282 }
2283 // Both sides reference the left input or no input.
2284 ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2285 predicates.push(OnPredicate::Lhs(predicate));
2286 }
2287 // Both sides reference the right input or no input.
2288 ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2289 predicate.permute(&rhs_permutation);
2290 predicates.push(OnPredicate::Rhs(predicate));
2291 }
2292 // At least one side references more than one input.
2293 _ => {
2294 tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2295 predicates.push(OnPredicate::Theta(predicate));
2296 }
2297 }
2298 } else {
2299 // Obtain the non-outer inputs referenced by this predicate.
2300 let inputs = lookup_inputs(&predicate);
2301
2302 match &inputs[..] {
2303 // The predicate references no inputs. This could be a
2304 // constant expression or an expression that depends only on
2305 // the outer context.
2306 [] => {
2307 predicates.push(OnPredicate::Const(predicate));
2308 }
2309 // The predicate references only the left input.
2310 [I_LHS] => {
2311 predicates.push(OnPredicate::Lhs(predicate));
2312 }
2313 // The predicate references only the right input.
2314 [I_RHS] => {
2315 predicate.permute(&rhs_permutation);
2316 predicates.push(OnPredicate::Rhs(predicate));
2317 }
2318 // The predicate references both inputs.
2319 _ => {
2320 tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2321 predicates.push(OnPredicate::Theta(predicate));
2322 }
2323 }
2324 }
2325 }
2326
2327 Self { predicates, oa }
2328 }
2329
2330 /// Check if the predicates can be lowered with an equijoin-based strategy.
2331 fn is_equijoin(&self, context: &Context) -> bool {
2332 // Count each `OnPredicate` variant in `self.predicates`.
2333 let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2334 self.predicates.iter().fold(
2335 (0, 0, 0, 0, 0, 0),
2336 |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2337 (
2338 const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2339 lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2340 rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2341 eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2342 eq_cols + usize::from(matches!(p, OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column())),
2343 theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2344 )
2345 },
2346 );
2347
2348 let is_equijion = if context.config.enable_new_outer_join_lowering {
2349 // New classifier.
2350 eq_cnt > 0 && theta_cnt == 0
2351 } else {
2352 // Old classifier.
2353 eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2354 };
2355
2356 // Log an entry only if this is an equijoin according to the new classifier.
2357 if eq_cnt > 0 && theta_cnt == 0 {
2358 tracing::debug!(
2359 const_cnt,
2360 lhs_cnt,
2361 rhs_cnt,
2362 eq_cnt,
2363 eq_cols,
2364 theta_cnt,
2365 "OnPredicates::is_equijoin"
2366 );
2367 }
2368
2369 is_equijion
2370 }
2371
2372 /// Return an [`MirRelationExpr`] list that represents the keys for the
2373 /// equijoin. The list will contain the outer columns as a prefix.
2374 fn join_keys(&self) -> JoinKeys {
2375 // We could return either the `lhs` or the `rhs` of the keys used to
2376 // form the inner join as they are equated by the join condition.
2377 let join_keys = self.eq_lhs().collect::<Vec<_>>();
2378
2379 if join_keys.iter().all(|k| k.is_column()) {
2380 tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2381 JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2382 } else {
2383 tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2384 JoinKeys::Scalars(join_keys)
2385 }
2386 }
2387
2388 /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2389 /// conditions in the predicates list.
2390 ///
2391 /// The iterator will start with column references to the outer columns as a
2392 /// prefix.
2393 fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2394 itertools::chain(
2395 (0..self.oa).map(MirScalarExpr::column),
2396 self.predicates.iter().filter_map(|e| match e {
2397 OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2398 _ => None,
2399 }),
2400 )
2401 }
2402
2403 /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2404 /// conditions in the predicates list.
2405 ///
2406 /// The iterator will start with column references to the outer columns as a
2407 /// prefix.
2408 fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2409 itertools::chain(
2410 (0..self.oa).map(MirScalarExpr::column),
2411 self.predicates.iter().filter_map(|e| match e {
2412 OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2413 _ => None,
2414 }),
2415 )
2416 }
2417
2418 /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2419 /// [`OnPredicate::Const`] conditions in the predicates list.
2420 fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2421 self.predicates.iter().filter_map(|p| match p {
2422 // We treat Const predicates local to both inputs.
2423 OnPredicate::Const(p) => Some(p.clone()),
2424 OnPredicate::Lhs(p) => Some(p.clone()),
2425 OnPredicate::LhsConsequence(p) => Some(p.clone()),
2426 _ => None,
2427 })
2428 }
2429
2430 /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2431 /// [`OnPredicate::Const`] conditions in the predicates list.
2432 fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2433 self.predicates.iter().filter_map(|p| match p {
2434 // We treat Const predicates local to both inputs.
2435 OnPredicate::Const(p) => Some(p.clone()),
2436 OnPredicate::Rhs(p) => Some(p.clone()),
2437 OnPredicate::RhsConsequence(p) => Some(p.clone()),
2438 _ => None,
2439 })
2440 }
2441}
2442
// A single `ON` predicate, tagged with the join inputs it references.
//
// Values are produced by `OnPredicates::new` and consumed by the accessor
// methods of `OnPredicates`.
enum OnPredicate {
    // A predicate that is either constant or references only outer columns.
    //
    // `OnPredicates::lhs` and `OnPredicates::rhs` both yield these, i.e. they
    // are treated as local to both inputs.
    Const(MirScalarExpr),
    // A local predicate on the left-hand side of the join, i.e., it references only the left input
    // and possibly outer columns.
    //
    // This is one of the original predicates from the ON clause.
    //
    // One _must_ apply this predicate.
    Lhs(MirScalarExpr),
    // A local predicate on the left-hand side of the join, i.e., it references only the left input
    // and possibly outer columns.
    //
    // This is not one of the original predicates from the ON clause, but is just a consequence
    // of an original predicate in the ON clause, where the original predicate references both
    // inputs, but the consequence references only the left input.
    //
    // For example, the original predicate `input1.x = input2.a` has the consequence
    // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
    // prevents null skew, and also makes more CSE opportunities available when the left input's key
    // doesn't have a NOT NULL constraint, saving us an arrangement.
    //
    // Applying the predicate is optional, because the original predicate will be applied anyway.
    LhsConsequence(MirScalarExpr),
    // A local predicate on the right-hand side of the join.
    //
    // This is one of the original predicates from the ON clause.
    //
    // One _must_ apply this predicate.
    //
    // `OnPredicates::new` permutes these so that they assume an `outer × right` schema.
    Rhs(MirScalarExpr),
    // A consequence of an original ON predicate that references only the right input; the
    // right-hand counterpart of `LhsConsequence`. Applying it is likewise optional.
    RhsConsequence(MirScalarExpr),
    // An equality predicate between the two sides.
    //
    // The first expression references the left input and the second one the right input;
    // `OnPredicates::new` swaps the operands if the original predicate had them reversed.
    Eq(MirScalarExpr, MirScalarExpr),
    // A non-equality predicate between the two sides, or any predicate that `OnPredicates::new`
    // refuses to classify (it might error, or it references an `on` subquery).
    //
    // `dead_code` is allowed because the payload is never read by the `OnPredicates` methods:
    // `is_equijoin` only matches on the variant itself.
    // NOTE(review): confirm that no code elsewhere in the file reads the payload.
    #[allow(dead_code)]
    Theta(MirScalarExpr),
}
2481
/// A set of join keys referencing an input.
///
/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
/// avoid changes (and thereby possible regressions) in plans that have equijoin
/// predicates consisting only of column refs.
///
/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
/// able to get rid of this code, but as it stands `Map` simplification seems
/// more cumbersome than `Project` simplification, so do this just to be sure.
enum JoinKeys {
    // Every key is a plain column reference; the payload lists the referenced column indices
    // in key order. `restrict` lowers this variant to a single `project`.
    Outputs(Vec<usize>),
    // At least one key is not a plain column reference; the payload lists the key expressions
    // in key order. `restrict` lowers this variant to a `map` followed by a `project`.
    Scalars(Vec<MirScalarExpr>),
}
2497
2498impl JoinKeys {
2499 fn len(&self) -> usize {
2500 match self {
2501 JoinKeys::Outputs(outputs) => outputs.len(),
2502 JoinKeys::Scalars(scalars) => scalars.len(),
2503 }
2504 }
2505}
2506
/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
/// code.
trait LoweringExt {
    /// Consumes `self` and `join_keys` and returns a value of the same type.
    ///
    /// See [`MirRelationExpr::restrict`].
    fn restrict(self, join_keys: JoinKeys) -> Self;
}
2513
2514impl LoweringExt for MirRelationExpr {
2515 /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2516 fn restrict(self, join_keys: JoinKeys) -> Self {
2517 let num_keys = join_keys.len();
2518 match join_keys {
2519 JoinKeys::Outputs(outputs) => self.project(outputs),
2520 JoinKeys::Scalars(scalars) => {
2521 let input_arity = self.arity();
2522 let outputs = (input_arity..input_arity + num_keys).collect();
2523 self.map(scalars).project(outputs)
2524 }
2525 }
2526 }
2527}