mz_sql/plan/lowering.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::func::variadic;
44use mz_expr::visit::Visit;
45use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
46use mz_ore::collections::CollectionExt;
47use mz_ore::stack::maybe_grow;
48use mz_repr::*;
49
50use crate::optimizer_metrics::OptimizerMetrics;
51use crate::plan::hir::{
52 AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
53};
54use crate::plan::{PlanError, transform_hir};
55use crate::session::vars::SystemVars;
56
// Submodule providing `attempt_left_join_magic`, which the non-correlated `Join` lowering
// below tries first for stacks of left joins when `enable_variadic_left_join_lowering`
// is set.
mod variadic_left;
58
/// Maps a leveled column reference to a specific column.
///
/// Leveled column references are nested, so that larger levels are
/// found early in a record and level zero is found at the end.
///
/// The column map only stores references for levels greater than zero,
/// and column references at level zero simply start at the first column
/// after all prior references.
#[derive(Debug, Clone)]
struct ColumnMap {
    /// Column position for each explicitly stored (outer) reference. Level-zero
    /// references are not looked up here; `get` computes them as
    /// `inner.len() + column`.
    inner: BTreeMap<ColumnRef, usize>,
}
71
72impl ColumnMap {
73 fn empty() -> ColumnMap {
74 Self::new(BTreeMap::new())
75 }
76
77 fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
78 ColumnMap { inner }
79 }
80
81 fn get(&self, col_ref: &ColumnRef) -> usize {
82 if col_ref.level == 0 {
83 self.inner.len() + col_ref.column
84 } else {
85 self.inner[col_ref]
86 }
87 }
88
89 fn len(&self) -> usize {
90 self.inner.len()
91 }
92
93 /// Updates references in the `ColumnMap` for use in a nested scope. The
94 /// provided `arity` must specify the arity of the current scope.
95 fn enter_scope(&self, arity: usize) -> ColumnMap {
96 // From the perspective of the nested scope, all existing column
97 // references will be one level greater.
98 let existing = self
99 .inner
100 .clone()
101 .into_iter()
102 .update(|(col, _i)| col.level += 1);
103
104 // All columns in the current scope become explicit entries in the
105 // immediate parent scope.
106 let new = (0..arity).map(|i| {
107 (
108 ColumnRef {
109 level: 1,
110 column: i,
111 },
112 self.len() + i,
113 )
114 });
115
116 ColumnMap::new(existing.chain(new).collect())
117 }
118}
119
/// Map with the CTEs currently in scope, keyed by the `LocalId` the CTE carries in the
/// input expression (the lowered ID is stored in `CteDesc::new_id`).
type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
122
/// Information needed when resolving a reference to a CTE in scope.
#[derive(Clone)]
struct CteDesc {
    /// The new ID assigned to the lowered version of the CTE, which may not match
    /// the ID of the input CTE.
    new_id: mz_expr::LocalId,
    /// The relation type of the CTE including the columns from the outer
    /// context at the beginning.
    relation_type: ReprRelationType,
    /// The outer relation the CTE was applied to.
    outer_relation: MirRelationExpr,
}
135
/// Feature flags that affect how HIR is lowered to MIR.
///
/// Every flag is disabled by default (the derived [`Default`] yields `false` for each
/// field).
#[derive(Debug, Clone, Copy, Default)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
    /// Feature flag for guarding table functions in subqueries.
    /// NOTE(review): its consumer is not visible in this part of the file; confirm the
    /// exact effect where it is read.
    pub enable_guard_subquery_tablefunc: bool,
    /// When set, scalar lowering drops `CastVarCharToString` calls and lowers the
    /// argument expression directly.
    pub enable_cast_elimination: bool,
}
156
157impl From<&SystemVars> for Config {
158 fn from(vars: &SystemVars) -> Self {
159 Self {
160 enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
161 enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
162 enable_guard_subquery_tablefunc: vars.enable_guard_subquery_tablefunc(),
163 enable_cast_elimination: vars.enable_cast_elimination(),
164 }
165 }
166}
167
/// Context passed to the lowering. This is wired to most parts of the lowering.
///
/// Both fields are borrowed for `'a`; `HirRelationExpr::lower` builds one `Context`
/// per call and passes it by reference to the recursive `applied_to` calls.
pub(crate) struct Context<'a> {
    /// Feature flags affecting the behavior of lowering.
    pub config: &'a Config,
    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's None, we
    /// simply don't write metrics.
    pub metrics: Option<&'a OptimizerMetrics>,
}
176
177impl HirRelationExpr {
178 /// Rewrite `self` into a `MirRelationExpr`.
179 /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
180 #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
181 pub fn lower<C: Into<Config>>(
182 self,
183 config: C,
184 metrics: Option<&OptimizerMetrics>,
185 ) -> Result<MirRelationExpr, PlanError> {
186 let context = Context {
187 config: &config.into(),
188 metrics,
189 };
190 let result = match self {
191 // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
192 // to ensure that the downstream optimizer can easily bypass most
193 // irrelevant optimizations (e.g. reduce folding) for this expression
194 // without having to re-learn the fact that it is just a constant,
195 // as it would if the constant were wrapped in a Let-Get pair.
196 HirRelationExpr::Constant { rows, typ } => {
197 let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
198 MirRelationExpr::Constant {
199 rows: Ok(rows),
200 typ: ReprRelationType::from(&typ),
201 }
202 }
203 mut other => {
204 let mut id_gen = mz_ore::id_gen::IdGen::default();
205 transform_hir::split_subquery_predicates(&mut other)?;
206 transform_hir::try_simplify_quantified_comparisons(&mut other)?;
207 transform_hir::fuse_window_functions(&mut other, &context)?;
208 MirRelationExpr::constant(vec![vec![]], ReprRelationType::new(vec![])).let_in(
209 &mut id_gen,
210 |id_gen, get_outer| {
211 other.applied_to(
212 id_gen,
213 get_outer,
214 &ColumnMap::empty(),
215 &mut CteMap::new(),
216 &context,
217 )
218 },
219 )?
220 }
221 };
222
223 mz_repr::explain::trace_plan(&result);
224
225 Ok(result)
226 }
227
    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
    ///
    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
    /// When `self` references columns of `get_outer`, much more work needs to occur.
    ///
    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
    /// perhaps not all of them. It should be used as the basis of resolving column references,
    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
    /// will start, rather than any function of `col_map`.
    ///
    /// NOTE(review): the body asserts `col_map.len() == get_outer.arity()`, so as written the
    /// map is expected to have exactly one entry per column of `get_outer`; the "perhaps not
    /// all of them" wording above looks stale.
    ///
    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
    /// assignment of values to outer rows.
    ///
    /// The remaining arguments are threaded through every recursive call: `id_gen` allocates
    /// fresh local IDs, `cte_map` holds the CTEs currently in scope (entries added for
    /// `Let`/`LetRec` are removed or restored before returning on the success path), and
    /// `context` carries the feature flags and optional metrics.
    ///
    /// # Panics
    ///
    /// Panics if `get_outer` is not a `MirRelationExpr::Get`, or if `col_map.len()` differs
    /// from `get_outer.arity()`. Individual arms also panic on broken invariants: a local
    /// `Get` whose ID is missing from `cte_map`, a `Map` scalar that references itself or a
    /// later column, a correlated join whose kind cannot be correlated, and a `TopK` offset
    /// that is not a non-negative integer literal.
    fn applied_to(
        self,
        id_gen: &mut mz_ore::id_gen::IdGen,
        get_outer: MirRelationExpr,
        col_map: &ColumnMap,
        cte_map: &mut CteMap,
        context: &Context,
    ) -> Result<MirRelationExpr, PlanError> {
        maybe_grow(|| {
            use MirRelationExpr as SR;

            use HirRelationExpr::*;

            // Enforce the documented precondition that `get_outer` is a `Get`.
            if let MirRelationExpr::Get { .. } = &get_outer {
            } else {
                panic!(
                    "get_outer: expected a MirRelationExpr::Get, found\n{}",
                    get_outer.pretty(),
                );
            }
            assert_eq!(col_map.len(), get_outer.arity());
            Ok(match self {
                Constant { rows, typ } => {
                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
                    get_outer.product(SR::Constant {
                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
                        typ: ReprRelationType::from(&typ),
                    })
                }
                Get { id, typ } => match id {
                    mz_expr::Id::Local(local_id) => {
                        // A local `Get` must refer to a CTE registered by an enclosing
                        // `Let`/`LetRec`; a missing entry panics here.
                        let cte_desc = cte_map.get(&local_id).unwrap();
                        let get_cte = SR::Get {
                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
                            typ: cte_desc.relation_type.clone(),
                            access_strategy: AccessStrategy::UnknownOrLocal,
                        };
                        if get_outer == cte_desc.outer_relation {
                            // If the CTE was applied to the same exact relation, we can safely
                            // return a `Get` relation.
                            get_cte
                        } else {
                            // Otherwise, the new outer relation may contain more columns from some
                            // intermediate scope placed between the definition of the CTE and this
                            // reference of the CTE and/or more operations applied on top of the
                            // outer relation.
                            //
                            // An example of the latter is the following query:
                            //
                            // SELECT *
                            // FROM x,
                            //      LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
                            //              SELECT (SELECT m FROM a) FROM y) b;
                            //
                            // When the CTE is lowered, the outer relation is `Get x`. But then,
                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
                            // which has the same cardinality as `Get x`.
                            //
                            // In any case, `get_outer` is guaranteed to contain the columns of the
                            // outer relation the CTE was applied to at its prefix. Since we must
                            // return a relation containing `get_outer`'s columns at the beginning,
                            // we must build a join between `get_outer` and `get_cte` on their common
                            // columns.
                            let oa = get_outer.arity();
                            // Number of leading columns of the lowered CTE that come from the
                            // outer relation it was applied to.
                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
                            let equivalences = (0..cte_outer_columns)
                                .map(|pos| {
                                    vec![
                                        MirScalarExpr::column(pos),
                                        MirScalarExpr::column(pos + oa),
                                    ]
                                })
                                .collect();

                            // Project out the second copy of the columns common to `get_outer`
                            // and `cte_desc.outer_relation`.
                            let projection = (0..oa)
                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
                                .collect_vec();
                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
                                .project(projection)
                        }
                    }
                    _ => {
                        // Get statements are only to external sources, and are not correlated with `get_outer`.
                        get_outer.product(SR::Get {
                            id,
                            typ: ReprRelationType::from(&typ),
                            access_strategy: AccessStrategy::UnknownOrLocal,
                        })
                    }
                },
                Let {
                    name: _,
                    id,
                    value,
                    body,
                } => {
                    let value =
                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    value.let_in(id_gen, |id_gen, get_value| {
                        // Extract the ID and type from the `Get` that `let_in` passes in for
                        // the bound value.
                        let (new_id, typ) = if let MirRelationExpr::Get {
                            id: mz_expr::Id::Local(id),
                            typ,
                            ..
                        } = get_value
                        {
                            (id, typ)
                        } else {
                            panic!(
                                "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
                                get_value.pretty(),
                            );
                        };
                        // Add the information about the CTE to the map and remove it when
                        // it goes out of scope.
                        let old_value = cte_map.insert(
                            id.clone(),
                            CteDesc {
                                new_id,
                                relation_type: typ,
                                outer_relation: get_outer.clone(),
                            },
                        );
                        // Not unwrapped with `?` here, so the map is restored below even
                        // when lowering the body fails.
                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
                        if let Some(old_value) = old_value {
                            cte_map.insert(id, old_value);
                        } else {
                            cte_map.remove(&id);
                        }
                        body
                    })?
                }
                LetRec {
                    limit,
                    bindings,
                    body,
                } => {
                    let num_bindings = bindings.len();

                    // We use the outer type with the HIR types to form MIR CTE types.
                    let outer_column_types = get_outer.typ().column_types;

                    // Rename and introduce all bindings.
                    // Every binding is registered before any value is lowered, so the
                    // values lowered below can resolve references to any of the bindings.
                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
                    let mut mir_ids = Vec::with_capacity(num_bindings);
                    for (_name, id, _value, typ) in bindings.iter() {
                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
                        mir_ids.push(mir_id);
                        let shadowed = cte_map.insert(
                            id.clone(),
                            CteDesc {
                                new_id: mir_id,
                                relation_type: ReprRelationType::new(
                                    outer_column_types
                                        .iter()
                                        .cloned()
                                        .chain(typ.column_types.iter().map(ReprColumnType::from))
                                        .collect::<Vec<_>>(),
                                ),
                                outer_relation: get_outer.clone(),
                            },
                        );
                        shadowed_bindings.push((*id, shadowed));
                    }

                    let mut mir_values = Vec::with_capacity(num_bindings);
                    for (_name, _id, value, _typ) in bindings.into_iter() {
                        mir_values.push(value.applied_to(
                            id_gen,
                            get_outer.clone(),
                            col_map,
                            cte_map,
                            context,
                        )?);
                    }

                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;

                    // Remove our bindings and reinstate any shadowed bindings.
                    for (id, shadowed) in shadowed_bindings {
                        if let Some(shadowed) = shadowed {
                            cte_map.insert(id, shadowed);
                        } else {
                            cte_map.remove(&id);
                        }
                    }

                    MirRelationExpr::LetRec {
                        ids: mir_ids,
                        values: mir_values,
                        // Copy the limit to each binding.
                        limits: repeat(limit).take(num_bindings).collect(),
                        body: Box::new(mir_body),
                    }
                }
                Project { input, outputs } => {
                    // Projections should be applied to the decorrelated `inner`, and to its columns,
                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
                    let input =
                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    let outputs = (0..get_outer.arity())
                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
                        .collect::<Vec<_>>();
                    input.project(outputs)
                }
                Map { input, mut scalars } => {
                    // Scalar expressions may contain correlated subqueries. We must be cautious!

                    // We lower scalars in chunks, and must keep track of the
                    // arity of the HIR fragments lowered so far.
                    let mut lowered_arity = input.arity();

                    let mut input =
                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;

                    // Lower subqueries in maximally sized batches, such that no subquery in the
                    // current batch depends on columns from the same batch.
                    // Note that subqueries in this projection may reference columns added by this
                    // Map operator, so we need to ensure these columns exist before lowering the
                    // subquery.
                    while !scalars.is_empty() {
                        // Index of the first remaining scalar that references a column of this
                        // scope (`col.level == depth`) at or beyond `lowered_arity`; everything
                        // before it forms the current batch.
                        let end_idx = scalars
                            .iter_mut()
                            .position(|s| {
                                let mut requires_nonexistent_column = false;
                                #[allow(deprecated)]
                                s.visit_columns(0, &mut |depth, col| {
                                    if col.level == depth {
                                        requires_nonexistent_column |= col.column >= lowered_arity
                                    }
                                });
                                requires_nonexistent_column
                            })
                            .unwrap_or(scalars.len());
                        // An empty batch would make no progress and loop forever.
                        assert!(
                            end_idx > 0,
                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
                            lowered_arity,
                            scalars
                        );

                        lowered_arity = lowered_arity + end_idx;
                        let scalars = scalars.drain(0..end_idx).collect_vec();

                        let old_arity = input.arity();
                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
                            &scalars, id_gen, col_map, cte_map, input, context,
                        )?;
                        input = with_subqueries;

                        // We will proceed sequentially through the scalar expressions, for each transforming
                        // the decorrelated `input` into a relation with potentially more columns capable of
                        // addressing the needs of the scalar expression.
                        // Having done so, we add the scalar value of interest and trim off any other newly
                        // added columns.
                        //
                        // The sequential traversal is present as expressions are allowed to depend on the
                        // values of prior expressions.
                        let mut scalar_columns = Vec::new();
                        for scalar in scalars {
                            let scalar = scalar.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                &mut input,
                                &Some(&subquery_map),
                                context,
                            )?;
                            input = input.map_one(scalar);
                            scalar_columns.push(input.arity() - 1);
                        }

                        // Discard any new columns added by the lowering of the scalar expressions
                        input = input.project((0..old_arity).chain(scalar_columns).collect());
                    }

                    input
                }
                CallTable { func, exprs } => {
                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
                    // allowed to refer to the results of previous expressions, and we have a simpler
                    // implementation that appends all relevant columns first, then applies the flatmap
                    // operator to the result, then strips off any columns introduced by subqueries.

                    let mut input = get_outer;
                    let old_arity = input.arity();

                    let exprs = exprs
                        .into_iter()
                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
                        .collect::<Result<Vec<_>, _>>()?;

                    let new_arity = input.arity();
                    let output_arity = func.output_arity();
                    input = input.flat_map(func, exprs);
                    if old_arity != new_arity {
                        // this means we added some columns to handle subqueries, and now we need to get rid of them
                        input = input.project(
                            (0..old_arity)
                                .chain(new_arity..new_arity + output_arity)
                                .collect(),
                        );
                    }
                    input
                }
                Filter { input, predicates } => {
                    // Filter expressions may contain correlated subqueries.
                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
                    // then filter the results, then strip off any columns that were added for this purpose.
                    let mut input =
                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
                    for predicate in predicates {
                        let old_arity = input.arity();
                        let predicate = predicate
                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
                        let new_arity = input.arity();
                        input = input.filter(vec![predicate]);
                        if old_arity != new_arity {
                            // this means we added some columns to handle subqueries, and now we need to get rid of them
                            input = input.project((0..old_arity).collect());
                        }
                    }
                    input
                }
                Join {
                    left,
                    right,
                    on,
                    kind,
                } if right.is_correlated() => {
                    // A correlated join is a join in which the right expression has
                    // access to the columns in the left expression. It turns out
                    // this is *exactly* our branch operator, plus some additional
                    // null handling in the case of left joins. (Right and full
                    // lateral joins are not permitted.)
                    //
                    // As with normal joins, the `on` predicate may be correlated,
                    // and we treat it as a filter that follows the branch.

                    assert!(kind.can_be_correlated());

                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
                    left.let_in(id_gen, |id_gen, get_left| {
                        let apply_requires_distinct_outer = false;
                        let mut join = branch(
                            id_gen,
                            get_left.clone(),
                            col_map,
                            cte_map,
                            *right,
                            apply_requires_distinct_outer,
                            context,
                            |id_gen, right, get_left, col_map, cte_map, context| {
                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
                            },
                        )?;

                        // Plan the `on` predicate.
                        let old_arity = join.arity();
                        let on =
                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
                        join = join.filter(vec![on]);
                        let new_arity = join.arity();
                        if old_arity != new_arity {
                            // This means we added some columns to handle
                            // subqueries, and now we need to get rid of them.
                            join = join.project((0..old_arity).collect());
                        }

                        // If a left join, reintroduce any rows from the left that
                        // are missing, with nulls filled in for the right columns.
                        if let JoinKind::LeftOuter { .. } = kind {
                            let default = join
                                .typ()
                                .column_types
                                .into_iter()
                                .skip(get_left.arity())
                                .map(|typ| (Datum::Null, typ.scalar_type))
                                .collect();
                            get_left.lookup(id_gen, join, default)
                        } else {
                            Ok::<_, PlanError>(join)
                        }
                    })?
                }
                Join {
                    left,
                    right,
                    on,
                    kind,
                } => {
                    if context.config.enable_variadic_left_join_lowering {
                        // Attempt to extract a stack of left joins.
                        if let JoinKind::LeftOuter = kind {
                            let mut rights = vec![(&*right, &on)];
                            let mut left_test = &left;
                            while let Join {
                                left,
                                right,
                                on,
                                kind: JoinKind::LeftOuter,
                            } = &**left_test
                            {
                                rights.push((&**right, on));
                                left_test = left;
                            }
                            if rights.len() > 1 {
                                // Defensively clone `cte_map` as it may be mutated.
                                let cte_map_clone = cte_map.clone();
                                // Both `Ok(None)` and `Err(_)` fall through to the general
                                // lowering below; an error from the attempt is discarded.
                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
                                    left_test,
                                    rights,
                                    id_gen,
                                    get_outer.clone(),
                                    col_map,
                                    cte_map,
                                    context,
                                ) {
                                    return Ok(magic);
                                } else {
                                    cte_map.clone_from(&cte_map_clone);
                                }
                            }
                        }
                    }

                    // Both join expressions should be decorrelated, and then joined by their
                    // leading columns to form only those pairs corresponding to the same row
                    // of `get_outer`.
                    //
                    // The `on` predicate may contain correlated subqueries, and we treat it
                    // as though it was a filter, with the caveat that we also translate outer
                    // joins in this step. The post-filtration results need to be considered
                    // against the records present in the left and right (decorrelated) inputs,
                    // depending on the type of join.
                    //
                    // Naming used below: `oa`, `la` and `ra` are the number of outer, left-only
                    // and right-only columns; `lt` and `rt` are the left-only and right-only
                    // column types.
                    let oa = get_outer.arity();
                    let left =
                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
                    let la = lt.len();
                    left.let_in(id_gen, |id_gen, get_left| {
                        let right_col_map = col_map.enter_scope(0);
                        let right = right.applied_to(
                            id_gen,
                            get_outer.clone(),
                            &right_col_map,
                            cte_map,
                            context,
                        )?;
                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
                        let ra = rt.len();
                        right.let_in(id_gen, |id_gen, get_right| {
                            let mut product = SR::join(
                                vec![get_left.clone(), get_right.clone()],
                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
                            )
                            // Project away the repeated copy of get_outer's columns.
                            .project(
                                (0..(oa + la))
                                    .chain((oa + la + oa)..(oa + la + oa + ra))
                                    .collect(),
                            );

                            // Decorrelate and lower the `on` clause.
                            let on = on.applied_to(
                                id_gen,
                                col_map,
                                cte_map,
                                &mut product,
                                &None,
                                context,
                            )?;
                            // Collect the types of all subqueries appearing in
                            // the `on` clause. The subquery results were
                            // appended to `product` in the `on.applied_to(...)`
                            // call above.
                            let on_subquery_types = product
                                .typ()
                                .column_types
                                .drain(oa + la + ra..)
                                .collect_vec();
                            // Remember if `on` had any subqueries.
                            let on_has_subqueries = !on_subquery_types.is_empty();

                            // Attempt an efficient equijoin implementation, in which outer joins are
                            // more efficiently rendered than in general. This can return `None` if
                            // such a plan is not possible, for example if `on` does not describe an
                            // equijoin between columns of `left` and `right`.
                            if kind != JoinKind::Inner {
                                if let Some(joined) = attempt_outer_equijoin(
                                    get_left.clone(),
                                    get_right.clone(),
                                    on.clone(),
                                    on_subquery_types,
                                    kind.clone(),
                                    oa,
                                    id_gen,
                                    context,
                                )? {
                                    if let Some(metrics) = context.metrics {
                                        metrics.inc_outer_join_lowering("equi");
                                    }
                                    return Ok(joined);
                                }
                            }

                            // Otherwise, perform a more general join.
                            if let Some(metrics) = context.metrics {
                                metrics.inc_outer_join_lowering("general");
                            }
                            let mut join = product.filter(vec![on]);
                            if on_has_subqueries {
                                // This means that `on.applied_to(...)` appended
                                // some columns to handle subqueries, and now we
                                // need to get rid of them.
                                join = join.project((0..oa + la + ra).collect());
                            }
                            join.let_in(id_gen, |id_gen, get_join| {
                                let mut result = get_join.clone();
                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
                                    kind
                                {
                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
                                        id_gen,
                                        get_join.clone(),
                                        rt.into_iter()
                                            .map(|typ| (Datum::Null, typ.scalar_type))
                                            .collect(),
                                    )?;
                                    result = result.union(left_outer);
                                }
                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                                    let right_outer = get_right
                                        .clone()
                                        .anti_lookup::<PlanError>(
                                            id_gen,
                                            get_join
                                                // need to swap left and right to make the anti_lookup work
                                                .project(
                                                    (0..oa)
                                                        .chain((oa + la)..(oa + la + ra))
                                                        .chain((oa)..(oa + la))
                                                        .collect(),
                                                ),
                                            lt.into_iter()
                                                .map(|typ| (Datum::Null, typ.scalar_type))
                                                .collect(),
                                        )?
                                        // swap left and right back again
                                        .project(
                                            (0..oa)
                                                .chain((oa + ra)..(oa + ra + la))
                                                .chain((oa)..(oa + ra))
                                                .collect(),
                                        );
                                    result = result.union(right_outer);
                                }
                                Ok::<MirRelationExpr, PlanError>(result)
                            })
                        })
                    })?
                }
                Union { base, inputs } => {
                    // Union is uncomplicated.
                    SR::Union {
                        base: Box::new(base.applied_to(
                            id_gen,
                            get_outer.clone(),
                            col_map,
                            cte_map,
                            context,
                        )?),
                        inputs: inputs
                            .into_iter()
                            .map(|input| {
                                input.applied_to(
                                    id_gen,
                                    get_outer.clone(),
                                    col_map,
                                    cte_map,
                                    context,
                                )
                            })
                            .collect::<Result<Vec<_>, _>>()?,
                    }
                }
                Reduce {
                    input,
                    group_key,
                    aggregates,
                    expected_group_size,
                } => {
                    // Reduce may contain expressions with correlated subqueries.
                    // In addition, here an empty reduction key signifies that we need to supply default values
                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
                    let mut input =
                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    // The outer columns are always part of the key, followed by the rebased
                    // group key columns.
                    let applied_group_key = (0..get_outer.arity())
                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
                        .collect();
                    let applied_aggregates = aggregates
                        .into_iter()
                        .map(|aggregate| {
                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    let input_type = input.typ();
                    let default = applied_aggregates
                        .iter()
                        .map(|agg| {
                            (
                                agg.func.default(),
                                agg.typ(&input_type.column_types).scalar_type,
                            )
                        })
                        .collect();
                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
                    let mut reduced =
                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);

                    // Introduce default values in the case the group key is empty.
                    if group_key.is_empty() {
                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
                    }
                    reduced
                }
                Distinct { input } => {
                    // Distinct is uncomplicated.
                    input
                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
                        .distinct()
                }
                TopK {
                    input,
                    group_key,
                    order_key,
                    limit,
                    offset,
                    expected_group_size,
                } => {
                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
                    let mut input =
                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
                        .collect();
                    let applied_order_key = order_key
                        .iter()
                        .map(|column_order| ColumnOrder {
                            column: column_order.column + get_outer.arity(),
                            desc: column_order.desc,
                            nulls_last: column_order.nulls_last,
                        })
                        .collect();

                    let old_arity = input.arity();

                    // Lower `limit`, which may introduce new columns if it is a correlated subquery.
                    let mut limit_mir = None;
                    if let Some(limit) = limit {
                        limit_mir = Some(
                            limit
                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
                        );
                    }

                    let new_arity = input.arity();
                    // Extend the key to contain any new columns.
                    applied_group_key.extend(old_arity..new_arity);

                    let offset = offset
                        .try_into_literal_int64()
                        .expect("Should be a Literal by this time")
                        .try_into()
                        .expect("Should have checked non-negativity of OFFSET clause already");
                    let mut result = input.top_k(
                        applied_group_key,
                        applied_order_key,
                        limit_mir,
                        offset,
                        expected_group_size,
                    );

                    // If new columns were added for `limit` we must remove them.
                    if old_arity != new_arity {
                        result = result.project((0..old_arity).collect());
                    }

                    result
                }
                Negate { input } => {
                    // Negate is uncomplicated.
                    input
                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
                        .negate()
                }
                Threshold { input } => {
                    // Threshold is uncomplicated.
                    input
                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
                        .threshold()
                }
            })
        })
    }
946}
947
948impl HirScalarExpr {
/// Rewrite `self` into a `mz_expr::ScalarExpr` which can be applied to the modified `inner`.
///
/// This method is responsible for decorrelating subqueries in `self` by introducing further columns
/// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
/// The most complicated logic is for the scalar expressions that involve subqueries, each of which is
/// documented in more detail closer to its logic.
///
/// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
/// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
/// each of these references can be found.
///
/// If `subquery_map` is provided and contains `self`, nothing is lowered: a reference to the mapped
/// column is returned instead.
///
/// # Errors
///
/// Propagates any `PlanError` raised while lowering a nested expression or subquery.
///
/// # Panics
///
/// Panics if a `Parameter` expression is reached during lowering.
fn applied_to(
    self,
    id_gen: &mut mz_ore::id_gen::IdGen,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: &mut MirRelationExpr,
    subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
    context: &Context,
) -> Result<MirScalarExpr, PlanError> {
    maybe_grow(|| {
        use MirScalarExpr as SS;

        use HirScalarExpr::*;

        // If this exact expression already has a column assigned in `subquery_map`, refer to
        // that column instead of lowering the expression again.
        if let Some(subquery_map) = subquery_map {
            if let Some(col) = subquery_map.get(&self) {
                return Ok(SS::column(*col));
            }
        }

        Ok::<MirScalarExpr, PlanError>(match self {
            // Column references are translated to physical positions through `col_map`.
            Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
            Literal(row, typ, _name) => SS::Literal(Ok(row), ReprColumnType::from(&typ)),
            Parameter(_, _name) => {
                panic!("cannot decorrelate expression with unbound parameters")
            }
            CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
            // With cast elimination enabled, the `CastVarCharToString` call is dropped and its
            // operand is lowered in its place.
            CallUnary {
                func: func::UnaryFunc::CastVarCharToString(_),
                expr,
                name: _,
            } if context.config.enable_cast_elimination => {
                expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?
            }
            CallUnary {
                func,
                expr,
                name: _,
            } => SS::CallUnary {
                func,
                expr: Box::new(expr.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    inner,
                    subquery_map,
                    context,
                )?),
            },
            CallBinary {
                func,
                expr1,
                expr2,
                name: _,
            } => SS::CallBinary {
                func,
                expr1: Box::new(expr1.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    inner,
                    subquery_map,
                    context,
                )?),
                expr2: Box::new(expr2.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    inner,
                    subquery_map,
                    context,
                )?),
            },
            CallVariadic {
                func,
                exprs,
                name: _,
            } => SS::call_variadic(
                func,
                exprs
                    .into_iter()
                    .map(|expr| {
                        expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
                    })
                    .collect::<Result<_, _>>()?,
            ),
            If {
                cond,
                then,
                els,
                name,
            } => {
                // The `If` case is complicated by the fact that we do not want to
                // apply the `then` or `else` logic to tuples that respectively do
                // not or do pass the `cond` test. Our strategy is to independently
                // decorrelate the `then` and `else` logic, and apply each to tuples
                // that respectively pass and do not pass the `cond` logic (which is
                // executed, and so decorrelated, for all tuples).
                //
                // Informally, we turn the `if` statement into:
                //
                //   let then_case = inner.filter(cond).map(then);
                //   let else_case = inner.filter(!cond).map(else);
                //   return then_case.concat(else_case);
                //
                // We only require this if either expression would result in any
                // computation beyond the expr itself, which we will interpret as
                // "introduces additional columns". In the absence of correlation,
                // we should just retain a `ScalarExpr::If` expression; the inverse
                // transformation as above is complicated to recover after the fact,
                // and we would benefit from not introducing the complexity.

                // Arity before `cond` is lowered. In the columns-added path below, both
                // branches are projected onto columns `0..inner_arity` plus one result
                // column, so the result ends up at position `inner_arity`.
                let inner_arity = inner.arity();
                let cond_expr =
                    cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;

                // Defensive copies, in case we mangle these in decorrelation.
                let inner_clone = inner.clone();
                let then_clone = then.clone();
                let else_clone = els.clone();

                // Arity after `cond` has been lowered; comparing against it later tells us
                // whether lowering `then` or `els` added any columns to `inner`.
                let cond_arity = inner.arity();
                let then_expr =
                    then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
                let else_expr =
                    els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;

                if cond_arity == inner.arity() {
                    // If no additional columns were added, we simply return the
                    // `If` variant with the updated expressions.
                    SS::If {
                        cond: Box::new(cond_expr),
                        then: Box::new(then_expr),
                        els: Box::new(else_expr),
                    }
                } else {
                    // If columns were added, we need a more careful approach, as
                    // described above. First, we need to de-correlate each of
                    // the two expressions independently, and apply their cases
                    // as `MirRelationExpr::Map` operations.
                    //
                    // Note that the speculative lowering of `then` and `els` above is
                    // discarded here: we restart from `inner_clone`, the state of `inner`
                    // right after `cond` was lowered.

                    *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
                        // Restrict to records satisfying `cond_expr` and apply `then` as a map.
                        let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
                        let then_expr = then_clone.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut then_inner,
                            subquery_map,
                            context,
                        )?;
                        // `then_arity` is the position of the column appended by `map_one`.
                        let then_arity = then_inner.arity();
                        then_inner = then_inner
                            .map_one(then_expr)
                            .project((0..inner_arity).chain(Some(then_arity)).collect());

                        // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
                        // "Not satisfying" means `cond_expr` is either false or NULL.
                        let mut else_inner = get_inner.filter(vec![SS::call_variadic(
                            variadic::Or,
                            vec![
                                cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
                                cond_expr.clone().call_is_null(),
                            ],
                        )]);
                        let else_expr = else_clone.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut else_inner,
                            subquery_map,
                            context,
                        )?;
                        // `else_arity` is the position of the column appended by `map_one`.
                        let else_arity = else_inner.arity();
                        else_inner = else_inner
                            .map_one(else_expr)
                            .project((0..inner_arity).chain(Some(else_arity)).collect());

                        // concatenate the two results.
                        Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
                    })?;

                    // Both branches were projected to `inner_arity + 1` columns, with the
                    // branch result in the last one.
                    SS::Column(inner_arity, name)
                }
            }

            // Subqueries!
            // These are surprisingly subtle. Things to be careful of:

            // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
            // * change the row counts of the outer query
            // * accidentally compute its own value using the row counts of the outer query
            // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
            // query and then join the answers back on to the original rows of the outer query.

            // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
            // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
            Exists(expr, name) => {
                let apply_requires_distinct_outer = true;
                // NOTE(review): `take_dangerous` presumably leaves a placeholder in `inner`
                // until it is overwritten by the assignment; confirm against its definition.
                *inner = apply_existential_subquery(
                    id_gen,
                    inner.take_dangerous(),
                    col_map,
                    cte_map,
                    *expr,
                    apply_requires_distinct_outer,
                    context,
                )?;
                // The subquery's value is read from the last column of the updated `inner`.
                SS::Column(inner.arity() - 1, name)
            }

            Select(expr, name) => {
                let apply_requires_distinct_outer = true;
                *inner = apply_scalar_subquery(
                    id_gen,
                    inner.take_dangerous(),
                    col_map,
                    cte_map,
                    *expr,
                    apply_requires_distinct_outer,
                    context,
                )?;
                // The subquery's value is read from the last column of the updated `inner`.
                SS::Column(inner.arity() - 1, name)
            }
            Windowing(expr, _name) => {
                let partition_by = expr.partition_by;
                let order_by = expr.order_by;

                // argument lowering for scalar window functions
                // (We need to specify the & _ in the arguments because of this problem:
                // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
                //
                // Scalar window functions take no arguments of their own, so the first five
                // parameters are unused. The aggregate input is
                // `record(list[original row record], order by exprs...)`.
                let scalar_lower_args =
                    |_id_gen: &mut _,
                     _col_map: &_,
                     _cte_map: &mut _,
                     _get_inner: &mut _,
                     _subquery_map: &Option<&_>,
                     order_by_mir: Vec<MirScalarExpr>,
                     original_row_record,
                     original_row_record_type: SqlScalarType| {
                        // Wrap the original row record in a single-element list.
                        let agg_input = MirScalarExpr::call_variadic(
                            variadic::ListCreate {
                                elem_type: original_row_record_type.clone(),
                            },
                            vec![original_row_record],
                        );
                        // Append the ORDER BY expressions and pack everything in a record.
                        let mut agg_input = vec![agg_input];
                        agg_input.extend(order_by_mir.clone());
                        let agg_input = MirScalarExpr::call_variadic(
                            variadic::RecordCreate {
                                field_names: (0..agg_input.len())
                                    .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
                                    .collect_vec(),
                            },
                            agg_input,
                        );
                        let list_type = SqlScalarType::List {
                            element_type: Box::new(original_row_record_type),
                            custom_id: None,
                        };
                        // NOTE(review): the declared record type has a single field (the
                        // list); the ORDER BY fields present in `agg_input` are not listed.
                        let agg_input_type = SqlScalarType::Record {
                            fields: std::iter::once(&list_type)
                                .map(|t| {
                                    (
                                        ColumnName::from(UNKNOWN_COLUMN_NAME),
                                        t.clone().nullable(false),
                                    )
                                })
                                .collect(),
                            custom_id: None,
                        }
                        .nullable(false);

                        Ok((agg_input, agg_input_type))
                    };

                // argument lowering for value window functions and aggregate window functions
                //
                // Takes the HIR-encoded arguments of the function and returns the closure
                // that `window_func_applied_to` will invoke to build the aggregate input.
                let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
                    |id_gen: &mut _,
                     col_map: &_,
                     cte_map: &mut _,
                     get_inner: &mut _,
                     subquery_map: &Option<&_>,
                     order_by_mir: Vec<MirScalarExpr>,
                     original_row_record,
                     original_row_record_type| {
                        // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]

                        // Compute the encoded args for all rows
                        let mir_encoded_args = hir_encoded_args.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            get_inner,
                            subquery_map,
                            context,
                        )?;
                        let mir_encoded_args_type = mir_encoded_args
                            .sql_typ(&get_inner.sql_typ().column_types)
                            .scalar_type;

                        // Build a new record that has two fields:
                        // 1. the original row in a record
                        // 2. the encoded args (which can be either a single value, or a record
                        //    if the window function has multiple arguments, such as `lag`)
                        let fn_input_record_fields: Box<[_]> =
                            [original_row_record_type, mir_encoded_args_type]
                                .iter()
                                .map(|t| {
                                    (
                                        ColumnName::from(UNKNOWN_COLUMN_NAME),
                                        t.clone().nullable(false),
                                    )
                                })
                                .collect();
                        let fn_input_record = MirScalarExpr::call_variadic(
                            variadic::RecordCreate {
                                field_names: fn_input_record_fields
                                    .iter()
                                    .map(|(n, _)| n.clone())
                                    .collect_vec(),
                            },
                            vec![original_row_record, mir_encoded_args],
                        );
                        let fn_input_record_type = SqlScalarType::Record {
                            fields: fn_input_record_fields,
                            custom_id: None,
                        }
                        .nullable(false);

                        // Build a new record with the record above + the ORDER BY exprs
                        // This follows the standard encoding of ORDER BY exprs used by aggregate functions
                        let mut agg_input = vec![fn_input_record];
                        agg_input.extend(order_by_mir.clone());
                        let agg_input = MirScalarExpr::call_variadic(
                            variadic::RecordCreate {
                                field_names: (0..agg_input.len())
                                    .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
                                    .collect_vec(),
                            },
                            agg_input,
                        );

                        // NOTE(review): as in the scalar case, the declared record type has a
                        // single field; the ORDER BY fields are not listed.
                        let agg_input_type = SqlScalarType::Record {
                            fields: [(
                                ColumnName::from(UNKNOWN_COLUMN_NAME),
                                fn_input_record_type.nullable(false),
                            )]
                            .into(),
                            custom_id: None,
                        }
                        .nullable(false);

                        Ok((agg_input, agg_input_type))
                    }
                };

                // All three kinds of window function share `window_func_applied_to`; they only
                // differ in the aggregate function and in how the aggregate input is built.
                match expr.func {
                    WindowExprType::Scalar(scalar_window_expr) => {
                        let mir_aggr_func = scalar_window_expr.into_expr();
                        Self::window_func_applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            inner,
                            subquery_map,
                            partition_by,
                            order_by,
                            mir_aggr_func,
                            scalar_lower_args,
                            context,
                        )?
                    }
                    WindowExprType::Value(value_window_expr) => {
                        let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();

                        Self::window_func_applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            inner,
                            subquery_map,
                            partition_by,
                            order_by,
                            mir_aggr_func,
                            value_or_aggr_lower_args(hir_encoded_args),
                            context,
                        )?
                    }
                    WindowExprType::Aggregate(aggr_window_expr) => {
                        let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();

                        Self::window_func_applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            inner,
                            subquery_map,
                            partition_by,
                            order_by,
                            mir_aggr_func,
                            value_or_aggr_lower_args(hir_encoded_args),
                            context,
                        )?
                    }
                }
            }
        })
    })
}
1369
/// Lowers a window function call over `inner`.
///
/// `inner` is replaced by a relation that carries the columns it had after the ORDER BY
/// expressions were lowered, followed by one column holding the result of the window function.
/// The returned expression is a reference to that last column.
///
/// * `partition_by` and `order_by` are lowered against `inner` (which may gain columns in the
///   process).
/// * `mir_aggr_func` is the aggregate function evaluated by the `Reduce`.
/// * `lower_args` builds the aggregate's input expression and its type. It receives the lowering
///   state, the lowered ORDER BY expressions, the original row packed into a record, and that
///   record's type.
///
/// # Errors
///
/// Propagates any `PlanError` raised while lowering the ORDER BY or PARTITION BY expressions, or
/// returned by `lower_args`.
fn window_func_applied_to<F>(
    id_gen: &mut mz_ore::id_gen::IdGen,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: &mut MirRelationExpr,
    subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
    partition_by: Vec<HirScalarExpr>,
    order_by: Vec<HirScalarExpr>,
    mir_aggr_func: AggregateFunc,
    lower_args: F,
    context: &Context,
) -> Result<MirScalarExpr, PlanError>
where
    F: FnOnce(
        &mut mz_ore::id_gen::IdGen,
        &ColumnMap,
        &mut CteMap,
        &mut MirRelationExpr,
        &Option<&BTreeMap<HirScalarExpr, usize>>,
        Vec<MirScalarExpr>,
        MirScalarExpr,
        SqlScalarType,
    ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
{
    // Example MIRs for a window function (specifically, a window aggregation):
    //
    // CREATE TABLE t7(x INT, y INT);
    //
    // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
    //
    // Decorrelated Plan
    // Project (#3)
    //   Map (#2)
    //     Project (#3..=#5)
    //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
    //         FlatMap unnest_list(#1)
    //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
    //             Map ((#0 + #1))
    //               CrossJoin
    //                 Constant
    //                   - ()
    //                 Get materialize.public.t7
    //
    // The same query after optimizations:
    //
    // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
    //
    // Optimized Plan
    // Explained Query:
    //   Project (#2)
    //     Map (record_get[0](#1))
    //       FlatMap unnest_list(#0)
    //         Project (#1)
    //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
    //             ReadStorage materialize.public.t7
    //
    // The `row(row(row(...), ...), ...)` stuff means the following:
    // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
    //   - The <arguments to window function> can be either a single value or itself a
    //     `row` if there are multiple arguments.
    //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
    //     ORDER BY columns.
    //   - The <original row> currently always captures the entire original row. This should
    //     improve when we make `ProjectionPushdown` smarter, see
    //     https://github.com/MaterializeInc/database-issues/issues/5090
    //
    // TODO:
    // We should probably introduce some dedicated Datum constructor functions instead of `row`
    // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
    // might even introduce dedicated Datum enum variants, so that the rendering code also
    // becomes more readable (and possibly slightly more performant).

    *inner = inner
        .take_dangerous()
        .let_in(id_gen, |id_gen, mut get_inner| {
            // Lower the ORDER BY expressions first; this may add columns to `get_inner`.
            let order_by_mir = order_by
                .into_iter()
                .map(|o| {
                    o.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        context,
                    )
                })
                .collect::<Result<Vec<_>, _>>()?;

            // Record input arity here so that any group_keys that need to mutate get_inner
            // don't add those columns to the aggregate input.
            let input_type = get_inner.sql_typ();
            let input_arity = input_type.arity();
            // The reduction that computes the window function must be keyed on the columns
            // from the outer context, plus the expressions in the partition key. The current
            // subquery will be 'executed' for every distinct row from the outer context so
            // by putting the outer columns in the grouping key we isolate each re-execution.
            let mut group_key = col_map
                .inner
                .iter()
                .map(|(_, outer_col)| *outer_col)
                .sorted()
                .collect_vec();
            for p in partition_by {
                let key = p.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    &mut get_inner,
                    subquery_map,
                    context,
                )?;
                // A plain column reference can be used as a key directly; any other
                // expression is first materialized as a new column.
                if let MirScalarExpr::Column(c, _name) = key {
                    group_key.push(c);
                } else {
                    get_inner = get_inner.map_one(key);
                    group_key.push(get_inner.arity() - 1);
                }
            }

            get_inner.let_in(id_gen, |id_gen, mut get_inner| {
                // Original columns of the relation
                let fields: Box<_> = input_type
                    .column_types
                    .iter()
                    .take(input_arity)
                    .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
                    .collect();

                // Original row made into a record
                let original_row_record = MirScalarExpr::call_variadic(
                    variadic::RecordCreate {
                        field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
                    },
                    (0..input_arity).map(MirScalarExpr::column).collect_vec(),
                );
                let original_row_record_type = SqlScalarType::Record {
                    fields,
                    custom_id: None,
                };

                // Let the caller-supplied callback assemble the aggregate input and its type.
                let (agg_input, agg_input_type) = lower_args(
                    id_gen,
                    col_map,
                    cte_map,
                    &mut get_inner,
                    subquery_map,
                    order_by_mir,
                    original_row_record,
                    original_row_record_type,
                )?;

                let aggregate = mz_expr::AggregateExpr {
                    func: mir_aggr_func,
                    expr: agg_input,
                    distinct: false,
                };

                // Actually call reduce with the window function
                // The output of the aggregation function should be a list of tuples that has
                // the result in the first position, and the original row in the second position
                //
                // After the reduce, the columns are the group key followed by the aggregate,
                // so the list to unnest sits at position `group_key.len()`.
                let mut reduce = get_inner
                    .reduce(group_key.clone(), vec![aggregate.clone()], None)
                    .flat_map(
                        mz_expr::TableFunc::UnnestList {
                            el_typ: aggregate
                                .func
                                .output_sql_type(agg_input_type)
                                .scalar_type
                                .unwrap_list_element_type()
                                .clone(),
                        },
                        vec![MirScalarExpr::column(group_key.len())],
                    );
                // The unnested record is the last column at this point.
                let record_col = reduce.arity() - 1;

                // Unpack the record output by the window function
                // (field 1 of the record is the original row; one column is appended per
                // original column).
                for c in 0..input_arity {
                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
                        expr: Box::new(MirScalarExpr::CallUnary {
                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
                            expr: Box::new(MirScalarExpr::column(record_col)),
                        }),
                    });
                }

                // Append the column with the result of the window function.
                reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                    func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
                    expr: Box::new(MirScalarExpr::column(record_col)),
                });

                // Keep only the `input_arity` unpacked columns and the result column that
                // follows them, dropping the group key, the list and the record.
                let agg_col = record_col + 1 + input_arity;
                Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
            })
        })?;
    // The window function result is the last column of the rebuilt `inner`.
    Ok(MirScalarExpr::column(inner.arity() - 1))
}
1569
1570 /// Applies the subqueries in the given list of scalar expressions to every distinct
1571 /// value of the given relation and returns a join of the given relation with all
1572 /// the subqueries found, and the mapping of scalar expressions with columns projected
1573 /// by the returned join that will hold their results.
1574 fn lower_subqueries(
1575 exprs: &[Self],
1576 id_gen: &mut mz_ore::id_gen::IdGen,
1577 col_map: &ColumnMap,
1578 cte_map: &mut CteMap,
1579 inner: MirRelationExpr,
1580 context: &Context,
1581 ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1582 let mut subquery_map = BTreeMap::new();
1583 let output = inner.let_in(id_gen, |id_gen, get_inner| {
1584 let mut subqueries = Vec::new();
1585 let distinct_inner = get_inner.clone().distinct();
1586 for expr in exprs.iter() {
1587 expr.visit_pre_post(
1588 &mut |e| match e {
1589 // For simplicity, subqueries within a conditional statement will be
1590 // lowered when lowering the conditional expression.
1591 HirScalarExpr::If { .. } => Some(vec![]),
1592 _ => None,
1593 },
1594 &mut |e| match e {
1595 HirScalarExpr::Select(expr, _name) => {
1596 let apply_requires_distinct_outer = false;
1597 let subquery = apply_scalar_subquery(
1598 id_gen,
1599 distinct_inner.clone(),
1600 col_map,
1601 cte_map,
1602 (**expr).clone(),
1603 apply_requires_distinct_outer,
1604 context,
1605 )
1606 .unwrap();
1607
1608 subqueries.push((e.clone(), subquery));
1609 }
1610 HirScalarExpr::Exists(expr, _name) => {
1611 let apply_requires_distinct_outer = false;
1612 let subquery = apply_existential_subquery(
1613 id_gen,
1614 distinct_inner.clone(),
1615 col_map,
1616 cte_map,
1617 (**expr).clone(),
1618 apply_requires_distinct_outer,
1619 context,
1620 )
1621 .unwrap();
1622 subqueries.push((e.clone(), subquery));
1623 }
1624 _ => {}
1625 },
1626 )?;
1627 }
1628
1629 if subqueries.is_empty() {
1630 Ok::<MirRelationExpr, PlanError>(get_inner)
1631 } else {
1632 let inner_arity = get_inner.arity();
1633 let mut total_arity = inner_arity;
1634 let mut join_inputs = vec![get_inner];
1635 let mut join_input_arities = vec![inner_arity];
1636 for (expr, subquery) in subqueries.into_iter() {
1637 // Avoid lowering duplicated subqueries
1638 if !subquery_map.contains_key(&expr) {
1639 let subquery_arity = subquery.arity();
1640 assert_eq!(subquery_arity, inner_arity + 1);
1641 join_inputs.push(subquery);
1642 join_input_arities.push(subquery_arity);
1643 total_arity += subquery_arity;
1644
1645 // Column with the value of the subquery
1646 subquery_map.insert(expr, total_arity - 1);
1647 }
1648 }
1649 // Each subquery projects all the columns of the outer context (distinct_inner)
1650 // plus 1 column, containing the result of the subquery. Those columns must be
1651 // joined with the outer/main relation (get_inner).
1652 let input_mapper =
1653 mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1654 let equivalences = (0..inner_arity)
1655 .map(|col| {
1656 join_inputs
1657 .iter()
1658 .enumerate()
1659 .map(|(input, _)| {
1660 MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1661 })
1662 .collect_vec()
1663 })
1664 .collect_vec();
1665 Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1666 }
1667 })?;
1668 Ok((output, subquery_map))
1669 }
1670
1671 /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1672 ///
1673 /// Returns an _internal_ error if the expression contains
1674 /// - a subquery
1675 /// - a column reference to an outer level
1676 /// - a parameter
1677 /// - a window function call
1678 ///
1679 /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1680 ///
1681 /// Set `enable_cast_elimination` to remove casts that are noops in MIR.
1682 pub fn lower_uncorrelated<C: Into<Config>>(
1683 self,
1684 config: C,
1685 ) -> Result<MirScalarExpr, PlanError> {
1686 let config = config.into();
1687
1688 use MirScalarExpr as SS;
1689
1690 use HirScalarExpr::*;
1691
1692 Ok(match self {
1693 Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1694 Literal(datum, typ, _name) => SS::Literal(Ok(datum), ReprColumnType::from(&typ)),
1695 CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1696 CallUnary {
1697 func: func::UnaryFunc::CastVarCharToString(_),
1698 expr,
1699 name: _,
1700 } if config.enable_cast_elimination => expr.lower_uncorrelated(config)?,
1701 CallUnary {
1702 func,
1703 expr,
1704 name: _,
1705 } => SS::CallUnary {
1706 func,
1707 expr: Box::new(expr.lower_uncorrelated(config)?),
1708 },
1709 CallBinary {
1710 func,
1711 expr1,
1712 expr2,
1713 name: _,
1714 } => SS::CallBinary {
1715 func,
1716 expr1: Box::new(expr1.lower_uncorrelated(config)?),
1717 expr2: Box::new(expr2.lower_uncorrelated(config)?),
1718 },
1719 CallVariadic {
1720 func,
1721 exprs,
1722 name: _,
1723 } => SS::call_variadic(
1724 func,
1725 exprs
1726 .into_iter()
1727 .map(|expr| expr.lower_uncorrelated(config))
1728 .collect::<Result<_, _>>()?,
1729 ),
1730 If {
1731 cond,
1732 then,
1733 els,
1734 name: _,
1735 } => SS::If {
1736 cond: Box::new(cond.lower_uncorrelated(config)?),
1737 then: Box::new(then.lower_uncorrelated(config)?),
1738 els: Box::new(els.lower_uncorrelated(config)?),
1739 },
1740 Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1741 sql_bail!(
1742 "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1743 self
1744 );
1745 }
1746 })
1747 }
1748}
1749
1750/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
1751/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
1752/// will, in effect, be executed once for every distinct row in `outer`, and the
1753/// results will be joined with `outer`. Note that columns in `outer` that are
1754/// not depended upon by `inner` are thrown away before the distinct, so that we
1755/// don't perform needless computation of `inner`.
1756///
1757/// `branch` will inspect the contents of `inner` to determine whether `inner`
1758/// is not multiplicity sensitive (roughly, contains only maps, filters,
1759/// projections, and calls to table functions). If it is not multiplicity
1760/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
1761/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
1762/// operations that were not present in `inner`, then set
1763/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
1764/// that distinctifies `outer`.
1765///
1766/// The caller must supply the `apply` function that applies the rewritten
1767/// `inner` to `outer`.
1768fn branch<F>(
1769 id_gen: &mut mz_ore::id_gen::IdGen,
1770 outer: MirRelationExpr,
1771 col_map: &ColumnMap,
1772 cte_map: &mut CteMap,
1773 inner: HirRelationExpr,
1774 apply_requires_distinct_outer: bool,
1775 context: &Context,
1776 apply: F,
1777) -> Result<MirRelationExpr, PlanError>
1778where
1779 F: FnOnce(
1780 &mut mz_ore::id_gen::IdGen,
1781 HirRelationExpr,
1782 MirRelationExpr,
1783 &ColumnMap,
1784 &mut CteMap,
1785 &Context,
1786 ) -> Result<MirRelationExpr, PlanError>,
1787{
1788 // TODO: It would be nice to have a version of this code w/o optimizations,
1789 // at the least for purposes of understanding. It was difficult for one reader
1790 // to understand the required properties of `outer` and `col_map`.
1791
1792 // If the inner expression is sufficiently simple, it is safe to apply it
1793 // *directly* to outer, rather than applying it to the distinctified key
1794 // (see below).
1795 //
1796 // As an example, consider the following two queries:
1797 //
1798 // CREATE TABLE t (a int, b int);
1799 // SELECT a, series FROM t, generate_series(1, t.b) series;
1800 //
1801 // The "simple" path for the `SELECT` yields
1802 //
1803 // %0 =
1804 // | Get t
1805 // | FlatMap generate_series(1, #1)
1806 //
1807 // while the non-simple path yields:
1808 //
1809 // %0 =
1810 // | Get t
1811 //
1812 // %1 =
1813 // | Get t
1814 // | Distinct group=(#1)
1815 // | FlatMap generate_series(1, #0)
1816 //
1817 // %2 =
1818 // | LeftJoin %1 %2 (= #1 #2)
1819 //
1820 // There is a tradeoff here: the simple plan is stateless, but the non-
1821 // simple plan may do (much) less computation if there are only a few
1822 // distinct values of `t.b`.
1823 //
1824 // We apply a very simple heuristic here and take the simple path if `inner`
1825 // contains only maps, filters, projections, and calls to table functions.
1826 // The intuition is that straightforward usage of table functions should
1827 // take the simple path, while everything else should not. (In theory we
1828 // think this transformation is valid as long as `inner` does not contain a
1829 // Reduce, Distinct, or TopK node, but it is not always an optimization in
1830 // the general case.)
1831 //
1832 // TODO(benesch): this should all be handled by a proper optimizer, but
1833 // detecting the moment of decorrelation in the optimizer right now is too
1834 // hard.
1835 let mut is_simple = true;
1836 #[allow(deprecated)]
1837 inner.visit(0, &mut |expr, _| match expr {
1838 HirRelationExpr::Constant { .. }
1839 | HirRelationExpr::Project { .. }
1840 | HirRelationExpr::Map { .. }
1841 | HirRelationExpr::Filter { .. }
1842 | HirRelationExpr::CallTable { .. } => (),
1843 _ => is_simple = false,
1844 });
1845 if is_simple && !apply_requires_distinct_outer {
1846 let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
1847 return outer.let_in(id_gen, |id_gen, get_outer| {
1848 apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
1849 });
1850 }
1851
1852 // The key consists of the columns from the outer expression upon which the
1853 // inner relation depends. We discover these dependencies by walking the
1854 // inner relation expression and looking for column references whose level
1855 // escapes inner.
1856 //
1857 // At the end of this process, `key` contains the decorrelated position of
1858 // each outer column, according to the passed-in `col_map`, and
1859 // `new_col_map` maps each outer column to its new ordinal position in key.
1860 let mut outer_cols = BTreeSet::new();
1861 #[allow(deprecated)]
1862 inner.visit_columns(0, &mut |depth, col| {
1863 // Test if the column reference escapes the subquery.
1864 if col.level > depth {
1865 outer_cols.insert(ColumnRef {
1866 level: col.level - depth,
1867 column: col.column,
1868 });
1869 }
1870 });
1871 // Collect all the outer columns referenced by any CTE referenced by
1872 // the inner relation.
1873 #[allow(deprecated)]
1874 inner.visit(0, &mut |e, _| match e {
1875 HirRelationExpr::Get {
1876 id: mz_expr::Id::Local(id),
1877 ..
1878 } => {
1879 if let Some(cte_desc) = cte_map.get(id) {
1880 let cte_outer_arity = cte_desc.outer_relation.arity();
1881 outer_cols.extend(
1882 col_map
1883 .inner
1884 .iter()
1885 .filter(|(_, position)| **position < cte_outer_arity)
1886 .map(|(c, _)| {
1887 // `col_map` maps column references to column positions in
1888 // `outer`'s projection.
1889 // `outer_cols` is meant to contain the external column
1890 // references in `inner`.
1891 // Since `inner` defines a new scope, any column reference
1892 // in `col_map` is one level deeper when seen from within
1893 // `inner`, hence the +1.
1894 ColumnRef {
1895 level: c.level + 1,
1896 column: c.column,
1897 }
1898 }),
1899 );
1900 }
1901 }
1902 HirRelationExpr::Let { id, .. } => {
1903 // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
1904 // we would need to remove the old CTE with the same ID temporarily while
1905 // traversing the definition of the new CTE under the same ID.
1906 assert!(!cte_map.contains_key(id));
1907 }
1908 _ => {}
1909 });
1910 let mut new_col_map = BTreeMap::new();
1911 let mut key = vec![];
1912 for col in outer_cols {
1913 new_col_map.insert(col, key.len());
1914 key.push(col_map.get(&ColumnRef {
1915 // Note: `outer_cols` contains the external column references within `inner`.
1916 // We must compensate for `inner`'s scope when translating column references
1917 // as seen within `inner` to column references as seen from `outer`'s context,
1918 // hence the -1.
1919 level: col.level - 1,
1920 column: col.column,
1921 }));
1922 }
1923 let new_col_map = ColumnMap::new(new_col_map);
1924 outer.let_in(id_gen, |id_gen, get_outer| {
1925 let keyed_outer = if key.is_empty() {
1926 // Don't depend on outer at all if the branch is not correlated,
1927 // which yields vastly better query plans. Note that this is a bit
1928 // weird in that the branch will be computed even if outer has no
1929 // rows, whereas if it had been correlated it would not (and *could*
1930 // not) have been computed if outer had no rows, but the callers of
1931 // this function don't mind these somewhat-weird semantics.
1932 MirRelationExpr::constant(vec![vec![]], ReprRelationType::new(vec![]))
1933 } else {
1934 get_outer.clone().distinct_by(key.clone())
1935 };
1936 keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
1937 let oa = get_outer.arity();
1938 let branch = apply(
1939 id_gen,
1940 inner,
1941 get_keyed_outer,
1942 &new_col_map,
1943 cte_map,
1944 context,
1945 )?;
1946 let ba = branch.arity();
1947 let joined = MirRelationExpr::join(
1948 vec![get_outer.clone(), branch],
1949 key.iter()
1950 .enumerate()
1951 .map(|(i, &k)| vec![(0, k), (1, i)])
1952 .collect(),
1953 )
1954 // throw away the right-hand copy of the key we just joined on
1955 .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
1956 Ok(joined)
1957 })
1958 })
1959}
1960
1961fn apply_scalar_subquery(
1962 id_gen: &mut mz_ore::id_gen::IdGen,
1963 outer: MirRelationExpr,
1964 col_map: &ColumnMap,
1965 cte_map: &mut CteMap,
1966 scalar_subquery: HirRelationExpr,
1967 apply_requires_distinct_outer: bool,
1968 context: &Context,
1969) -> Result<MirRelationExpr, PlanError> {
1970 branch(
1971 id_gen,
1972 outer,
1973 col_map,
1974 cte_map,
1975 scalar_subquery,
1976 apply_requires_distinct_outer,
1977 context,
1978 |id_gen, expr, get_inner, col_map, cte_map, context| {
1979 // compute for every row in get_inner
1980 let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1981 let col_type = select.sql_typ().column_types.into_last();
1982
1983 let inner_arity = get_inner.arity();
1984 // We must determine a count for each `get_inner` prefix,
1985 // and report an error if that count exceeds one.
1986 let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1987 // Count for each `get_inner` prefix.
1988 let counts = get_select.clone().reduce(
1989 (0..inner_arity).collect::<Vec<_>>(),
1990 vec![mz_expr::AggregateExpr {
1991 func: mz_expr::AggregateFunc::Count,
1992 expr: MirScalarExpr::literal_true(),
1993 distinct: false,
1994 }],
1995 None,
1996 );
1997
1998 let use_guard = context.config.enable_guard_subquery_tablefunc;
1999
2000 // Errors should result from counts > 1.
2001 let errors = if use_guard {
2002 counts
2003 .flat_map(
2004 mz_expr::TableFunc::GuardSubquerySize {
2005 column_type: col_type.clone().scalar_type,
2006 },
2007 vec![MirScalarExpr::column(inner_arity)],
2008 )
2009 .project(
2010 (0..inner_arity)
2011 .chain(Some(inner_arity + 1))
2012 .collect::<Vec<_>>(),
2013 )
2014 } else {
2015 counts
2016 .filter(vec![MirScalarExpr::column(inner_arity).call_binary(
2017 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2018 func::Gt,
2019 )])
2020 .project((0..inner_arity).collect::<Vec<_>>())
2021 .map_one(MirScalarExpr::literal(
2022 Err(mz_expr::EvalError::MultipleRowsFromSubquery),
2023 ReprScalarType::from(&col_type.scalar_type),
2024 ))
2025 };
2026 // Return `get_select` and any errors added in.
2027 Ok::<_, PlanError>(get_select.union(errors))
2028 })?;
2029 // append Null to anything that didn't return any rows
2030 let default = vec![(Datum::Null, ReprScalarType::from(&col_type.scalar_type))];
2031 get_inner.lookup(id_gen, guarded, default)
2032 },
2033 )
2034}
2035
2036fn apply_existential_subquery(
2037 id_gen: &mut mz_ore::id_gen::IdGen,
2038 outer: MirRelationExpr,
2039 col_map: &ColumnMap,
2040 cte_map: &mut CteMap,
2041 subquery_expr: HirRelationExpr,
2042 apply_requires_distinct_outer: bool,
2043 context: &Context,
2044) -> Result<MirRelationExpr, PlanError> {
2045 branch(
2046 id_gen,
2047 outer,
2048 col_map,
2049 cte_map,
2050 subquery_expr,
2051 apply_requires_distinct_outer,
2052 context,
2053 |id_gen, expr, get_inner, col_map, cte_map, context| {
2054 let exists = expr
2055 // compute for every row in get_inner
2056 .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2057 // throw away actual values and just remember whether or not there were __any__ rows
2058 .distinct_by((0..get_inner.arity()).collect())
2059 // Append true to anything that returned any rows.
2060 .map(vec![MirScalarExpr::literal_true()]);
2061
2062 // append False to anything that didn't return any rows
2063 get_inner.lookup(id_gen, exists, vec![(Datum::False, ReprScalarType::Bool)])
2064 },
2065 )
2066}
2067
2068impl AggregateExpr {
2069 fn applied_to(
2070 self,
2071 id_gen: &mut mz_ore::id_gen::IdGen,
2072 col_map: &ColumnMap,
2073 cte_map: &mut CteMap,
2074 inner: &mut MirRelationExpr,
2075 context: &Context,
2076 ) -> Result<mz_expr::AggregateExpr, PlanError> {
2077 let AggregateExpr {
2078 func,
2079 expr,
2080 distinct,
2081 } = self;
2082
2083 Ok(mz_expr::AggregateExpr {
2084 func: func.into_expr(),
2085 expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2086 distinct,
2087 })
2088 }
2089}
2090
/// Attempts an efficient outer join, if `on` has equijoin structure.
///
/// Both `left` and `right` are decorrelated inputs.
///
/// The first `oa` columns correspond to an outer context: we should do the
/// outer join independently for each prefix. In the case that `on` contains
/// just some equality tests between columns of `left` and `right` and some
/// local predicates, we can employ a relatively simple plan.
///
/// The last `on_subquery_types.len()` columns correspond to results from
/// subqueries defined in the `on` clause - we treat those as theta-join
/// conditions that prohibit the use of the simple plan attempted here.
///
/// # Returns
///
/// - `Ok(None)` if the classified `on` predicates are rejected by
///   `OnPredicates::is_equijoin`; nothing has been planned in that case.
/// - `Ok(Some(plan))` otherwise. The plan has the schema
///   `outer × left × right`: the filtered inner join, plus (for left and full
///   outer joins) the unmatched `left` rows padded with typed nulls on the
///   right, plus (for right and full outer joins) the unmatched `right` rows
///   padded with typed nulls on the left.
///
/// # Panics
///
/// Panics if, after removing the outer, left and right column types from the
/// assembled output type, the number of remaining types is not
/// `on_subquery_types.len()`.
fn attempt_outer_equijoin(
    left: MirRelationExpr,
    right: MirRelationExpr,
    on: MirScalarExpr,
    on_subquery_types: Vec<ReprColumnType>,
    kind: JoinKind,
    oa: usize,
    id_gen: &mut mz_ore::id_gen::IdGen,
    context: &Context,
) -> Result<Option<MirRelationExpr>, PlanError> {
    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
    // predicates that reference subqueries as long as these subqueries don't
    // reference `left` and `right` at the same time.
    //
    // TODO(database-issues#6828): This code can be improved as follows:
    //
    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
    // 2. Use the canonicalized `on` predicate in the non-equijoin based
    //    lowering strategy.
    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
    // 4. Pass the classified `OnPredicates` as a parameter.
    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
    //
    // Steps (1 + 2) require further investigation because we might change the
    // error semantics in case the `on` predicate contains a literal error.

    // Arities of the non-outer parts of `left` and `right`, and of the `on`
    // subquery results. Both inputs are expected to carry the `oa` outer
    // columns as a prefix.
    let l_type = left.typ();
    let r_type = right.typ();
    let la = l_type.column_types.len() - oa;
    let ra = r_type.column_types.len() - oa;
    let sa = on_subquery_types.len();

    // The output type contains [outer, left, right, sa] attributes.
    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
    output_type.extend(l_type.column_types);
    // Skip the outer prefix of `right`, it is already covered by `left`.
    output_type.extend(r_type.column_types.into_iter().skip(oa));
    output_type.extend(on_subquery_types);

    // Generally healthy to do, but specifically `USING` conditions sometimes
    // put an `AND true` at the end of the `ON` condition.
    //
    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
    // However, in that case it's not clear that we won't see regressions if
    // `on` simplifies to a literal error.
    let mut on = vec![on];
    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);

    // Form the left and right types without the outer attributes.
    output_type.drain(0..oa);
    let lt = output_type.drain(0..la).collect_vec();
    let rt = output_type.drain(0..ra).collect_vec();
    // Only the types of the `on` subquery results may remain.
    assert!(output_type.len() == sa);

    // Classify a copy of the predicates; the original `on` is still needed
    // as the filter of the inner join below.
    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
    if !on_predicates.is_equijoin(context) {
        return Ok(None);
    }

    // If we've gotten this far, we can do the clever thing.
    // We'll want to use left and right multiple times
    let result = left.let_in(id_gen, |id_gen, get_left| {
        right.let_in(id_gen, |id_gen, get_right| {
            // TODO: we know that we can re-use the arrangements of left and right
            // needed for the inner join with each of the conditional outer joins.
            // It is not clear whether we should hint that, or just let the planner
            // and optimizer run and see what happens.

            // We'll want the inner join (minus repeated columns)
            let join = MirRelationExpr::join(
                vec![get_left.clone(), get_right.clone()],
                // Equate the outer prefixes of both inputs.
                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // remove those columns from `right` repeating the first `oa` columns.
            .project(
                (0..(oa + la))
                    .chain((oa + la + oa)..(oa + la + oa + ra))
                    .collect(),
            )
            // apply the filter constraints here, to ensure nulls are not matched.
            .filter(on);

            // We'll want to re-use the results of the join multiple times.
            join.let_in(id_gen, |id_gen, get_join| {
                let mut result = get_join.clone();

                // A collection of keys present in both left and right collections.
                let join_keys = on_predicates.join_keys();
                let both_keys_arity = join_keys.len();
                let both_keys = get_join.restrict(join_keys).distinct();

                // The plan is now to determine the left and right rows matched in the
                // inner join, subtract them from left and right respectively, pad what
                // remains with nulls, and fold them in to `result`.

                both_keys.let_in(id_gen, |_id_gen, get_both| {
                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
                        // Rows in `left` matched in the inner equijoin. This is
                        // a semi-join between `left` and `both_keys`.
                        let left_present = MirRelationExpr::join_scalars(
                            vec![
                                get_left
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.lhs()),
                                get_both.clone(),
                            ],
                            // Pair each left key expression with the matching
                            // column of `both_keys`, which follows the
                            // `oa + la` columns of `left` in the join.
                            itertools::zip_eq(
                                on_predicates.eq_lhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
                            )
                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
                            .collect(),
                        )
                        // Keep only the columns of `left`.
                        .project((0..(oa + la)).collect());

                        // Determine the types of nulls to use as filler.
                        let right_fill = rt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();
                        // Add to `result` absent elements, filled with typed nulls.
                        result = left_present
                            .negate()
                            .union(get_left.clone())
                            .map(right_fill)
                            .union(result);
                    }

                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                        // Rows in `right` matched in the inner equijoin. This
                        // is a semi-join between `right` and `both_keys`.
                        let right_present = MirRelationExpr::join_scalars(
                            vec![
                                get_right
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.rhs()),
                                get_both,
                            ],
                            // Pair each right key expression with the matching
                            // column of `both_keys`, which follows the
                            // `oa + ra` columns of `right` in the join.
                            itertools::zip_eq(
                                on_predicates.eq_rhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
                            )
                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
                            .collect(),
                        )
                        // Keep only the columns of `right`.
                        .project((0..(oa + ra)).collect());

                        // Determine the types of nulls to use as filler.
                        let left_fill = lt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, prepended with typed nulls.
                        result = right_present
                            .negate()
                            .union(get_right.clone())
                            .map(left_fill)
                            // Permute left fill before right values: the input
                            // here is `outer × right × left_fill`, the output
                            // must be `outer × left_fill × right`.
                            .project(
                                itertools::chain!(
                                    0..oa,                 // The `oa` outer columns.
                                    oa + ra..oa + la + ra, // The `la` null fillers appended by `map`.
                                    oa..oa + ra            // The `ra` columns of `right`.
                                )
                                .collect(),
                            )
                            .union(result)
                    }

                    Ok::<_, PlanError>(result)
                })
            })
        })
    })?;
    Ok(Some(result))
}
2281
/// A struct that represents the predicates in the `on` clause in a form
/// suitable for efficient planning outer joins with equijoin predicates.
///
/// Instances are built by `OnPredicates::new`, which classifies each
/// predicate into an [`OnPredicate`] variant.
struct OnPredicates {
    /// A store for classified `ON` predicates.
    ///
    /// Predicates that reference a single side are adjusted to assume an
    /// `outer × <side>` schema.
    predicates: Vec<OnPredicate>,
    /// Number of outer context columns.
    ///
    /// Used to prefix the equijoin keys with references to the outer columns.
    oa: usize,
}
2293
2294impl OnPredicates {
2295 const I_OUT: usize = 0; // outer context input position
2296 const I_LHS: usize = 1; // lhs input position
2297 const I_RHS: usize = 2; // rhs input position
2298 const I_SUB: usize = 3; // on subqueries input position
2299
2300 /// Classify the predicates in the `on` clause of an outer join.
2301 ///
2302 /// The other parameters are arities of the input parts:
2303 ///
2304 /// - `oa` is the arity of the `outer` context.
2305 /// - `la` is the arity of the `left` input.
2306 /// - `ra` is the arity of the `right` input.
2307 /// - `sa` is the arity of the `on` subqueries.
2308 ///
2309 /// The constructor assumes that:
2310 ///
2311 /// 1. The `on` parameter will be applied on a result that has the following
2312 /// schema `outer × left × right × on_subqueries`.
2313 /// 2. The `on` parameter is already adjusted to assume that schema.
2314 /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2315 /// MirScalarExpr` with `canonicalize_predicates`.
2316 fn new(
2317 oa: usize,
2318 la: usize,
2319 ra: usize,
2320 sa: usize,
2321 on: Vec<MirScalarExpr>,
2322 _context: &Context,
2323 ) -> Self {
2324 use mz_expr::BinaryFunc::Eq;
2325
2326 // Re-bind those locally for more compact pattern matching.
2327 const I_LHS: usize = OnPredicates::I_LHS;
2328 const I_RHS: usize = OnPredicates::I_RHS;
2329
2330 // Self parameters.
2331 let mut predicates = Vec::with_capacity(on.len());
2332
2333 // Helpers for populating `predicates`.
2334 let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2335 let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2336 let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2337 inner_join_mapper
2338 .lookup_inputs(expr)
2339 .filter(|&i| i != Self::I_OUT)
2340 .collect()
2341 };
2342 let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2343 inner_join_mapper
2344 .lookup_inputs(expr)
2345 .any(|i| i == Self::I_SUB)
2346 };
2347
2348 // Iterate over `on` elements and populate `predicates`.
2349 for mut predicate in on {
2350 if predicate.might_error() {
2351 tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2352 // Treat predicates that can produce a literal error as Theta.
2353 predicates.push(OnPredicate::Theta(predicate));
2354 } else if has_subquery_refs(&predicate) {
2355 tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2356 // Treat predicates referencing an `on` subquery as Theta.
2357 predicates.push(OnPredicate::Theta(predicate));
2358 } else if let MirScalarExpr::CallBinary {
2359 func: Eq(_),
2360 expr1,
2361 expr2,
2362 } = &mut predicate
2363 {
2364 // Obtain the non-outer inputs referenced by each side.
2365 let inputs1 = lookup_inputs(expr1);
2366 let inputs2 = lookup_inputs(expr2);
2367
2368 match (&inputs1[..], &inputs2[..]) {
2369 // Neither side references an input. This could be a
2370 // constant expression or an expression that depends only on
2371 // the outer context.
2372 ([], []) => {
2373 predicates.push(OnPredicate::Const(predicate));
2374 }
2375 // Both sides reference different inputs.
2376 ([I_LHS], [I_RHS]) => {
2377 let lhs = expr1.take();
2378 let mut rhs = expr2.take();
2379 rhs.permute(&rhs_permutation);
2380 predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2381 predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2382 predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2383 }
2384 // Both sides reference different inputs (swapped).
2385 ([I_RHS], [I_LHS]) => {
2386 let lhs = expr2.take();
2387 let mut rhs = expr1.take();
2388 rhs.permute(&rhs_permutation);
2389 predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2390 predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2391 predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2392 }
2393 // Both sides reference the left input or no input.
2394 ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2395 predicates.push(OnPredicate::Lhs(predicate));
2396 }
2397 // Both sides reference the right input or no input.
2398 ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2399 predicate.permute(&rhs_permutation);
2400 predicates.push(OnPredicate::Rhs(predicate));
2401 }
2402 // At least one side references more than one input.
2403 _ => {
2404 tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2405 predicates.push(OnPredicate::Theta(predicate));
2406 }
2407 }
2408 } else {
2409 // Obtain the non-outer inputs referenced by this predicate.
2410 let inputs = lookup_inputs(&predicate);
2411
2412 match &inputs[..] {
2413 // The predicate references no inputs. This could be a
2414 // constant expression or an expression that depends only on
2415 // the outer context.
2416 [] => {
2417 predicates.push(OnPredicate::Const(predicate));
2418 }
2419 // The predicate references only the left input.
2420 [I_LHS] => {
2421 predicates.push(OnPredicate::Lhs(predicate));
2422 }
2423 // The predicate references only the right input.
2424 [I_RHS] => {
2425 predicate.permute(&rhs_permutation);
2426 predicates.push(OnPredicate::Rhs(predicate));
2427 }
2428 // The predicate references both inputs.
2429 _ => {
2430 tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2431 predicates.push(OnPredicate::Theta(predicate));
2432 }
2433 }
2434 }
2435 }
2436
2437 Self { predicates, oa }
2438 }
2439
2440 /// Check if the predicates can be lowered with an equijoin-based strategy.
2441 fn is_equijoin(&self, context: &Context) -> bool {
2442 // Count each `OnPredicate` variant in `self.predicates`.
2443 let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2444 self.predicates.iter().fold(
2445 (0, 0, 0, 0, 0, 0),
2446 |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2447 (
2448 const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2449 lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2450 rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2451 eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2452 eq_cols
2453 + usize::from(matches!(
2454 p,
2455 OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column()
2456 )),
2457 theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2458 )
2459 },
2460 );
2461
2462 let is_equijion = if context.config.enable_new_outer_join_lowering {
2463 // New classifier.
2464 eq_cnt > 0 && theta_cnt == 0
2465 } else {
2466 // Old classifier.
2467 eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2468 };
2469
2470 // Log an entry only if this is an equijoin according to the new classifier.
2471 if eq_cnt > 0 && theta_cnt == 0 {
2472 tracing::debug!(
2473 const_cnt,
2474 lhs_cnt,
2475 rhs_cnt,
2476 eq_cnt,
2477 eq_cols,
2478 theta_cnt,
2479 "OnPredicates::is_equijoin"
2480 );
2481 }
2482
2483 is_equijion
2484 }
2485
2486 /// Return an [`MirRelationExpr`] list that represents the keys for the
2487 /// equijoin. The list will contain the outer columns as a prefix.
2488 fn join_keys(&self) -> JoinKeys {
2489 // We could return either the `lhs` or the `rhs` of the keys used to
2490 // form the inner join as they are equated by the join condition.
2491 let join_keys = self.eq_lhs().collect::<Vec<_>>();
2492
2493 if join_keys.iter().all(|k| k.is_column()) {
2494 tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2495 JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2496 } else {
2497 tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2498 JoinKeys::Scalars(join_keys)
2499 }
2500 }
2501
2502 /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2503 /// conditions in the predicates list.
2504 ///
2505 /// The iterator will start with column references to the outer columns as a
2506 /// prefix.
2507 fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2508 itertools::chain(
2509 (0..self.oa).map(MirScalarExpr::column),
2510 self.predicates.iter().filter_map(|e| match e {
2511 OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2512 _ => None,
2513 }),
2514 )
2515 }
2516
2517 /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2518 /// conditions in the predicates list.
2519 ///
2520 /// The iterator will start with column references to the outer columns as a
2521 /// prefix.
2522 fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2523 itertools::chain(
2524 (0..self.oa).map(MirScalarExpr::column),
2525 self.predicates.iter().filter_map(|e| match e {
2526 OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2527 _ => None,
2528 }),
2529 )
2530 }
2531
2532 /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2533 /// [`OnPredicate::Const`] conditions in the predicates list.
2534 fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2535 self.predicates.iter().filter_map(|p| match p {
2536 // We treat Const predicates local to both inputs.
2537 OnPredicate::Const(p) => Some(p.clone()),
2538 OnPredicate::Lhs(p) => Some(p.clone()),
2539 OnPredicate::LhsConsequence(p) => Some(p.clone()),
2540 _ => None,
2541 })
2542 }
2543
2544 /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2545 /// [`OnPredicate::Const`] conditions in the predicates list.
2546 fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2547 self.predicates.iter().filter_map(|p| match p {
2548 // We treat Const predicates local to both inputs.
2549 OnPredicate::Const(p) => Some(p.clone()),
2550 OnPredicate::Rhs(p) => Some(p.clone()),
2551 OnPredicate::RhsConsequence(p) => Some(p.clone()),
2552 _ => None,
2553 })
2554 }
2555}
2556
/// A single classified predicate from the `ON` clause of an outer join, as
/// produced by `OnPredicates::new`.
enum OnPredicate {
    /// A predicate that is either constant or references only outer columns.
    Const(MirScalarExpr),
    /// A local predicate on the left-hand side of the join, i.e., it references only the left
    /// input and possibly outer columns.
    ///
    /// This is one of the original predicates from the ON clause.
    ///
    /// One _must_ apply this predicate.
    Lhs(MirScalarExpr),
    /// A local predicate on the left-hand side of the join, i.e., it references only the left
    /// input and possibly outer columns.
    ///
    /// This is not one of the original predicates from the ON clause, but is just a consequence
    /// of an original predicate in the ON clause, where the original predicate references both
    /// inputs, but the consequence references only the left input.
    ///
    /// For example, the original predicate `input1.x = input2.a` has the consequence
    /// `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
    /// prevents null skew, and also makes more CSE opportunities available when the left input's
    /// key doesn't have a NOT NULL constraint, saving us an arrangement.
    ///
    /// Applying the predicate is optional, because the original predicate will be applied anyway.
    LhsConsequence(MirScalarExpr),
    /// A local predicate on the right-hand side of the join, i.e., it references only the right
    /// input and possibly outer columns. It is expressed against the `outer × right` schema.
    ///
    /// This is one of the original predicates from the ON clause.
    ///
    /// One _must_ apply this predicate.
    Rhs(MirScalarExpr),
    /// The right-hand side counterpart of [`OnPredicate::LhsConsequence`]: a consequence of an
    /// original ON predicate that references only the right input. Applying it is optional.
    RhsConsequence(MirScalarExpr),
    /// An equality predicate between the two sides.
    ///
    /// The first expression references the left input, the second one references the right
    /// input and is expressed against the `outer × right` schema.
    Eq(MirScalarExpr, MirScalarExpr),
    /// A non-equality predicate between the two sides, or any other predicate that rules out
    /// the equijoin-based lowering (e.g. one that might error or references an `on` subquery).
    // NOTE(review): `dead_code` is presumably allowed because the payload is never read, only
    // the presence of the variant is checked - confirm against the rest of the file.
    #[allow(dead_code)]
    Theta(MirScalarExpr),
}
2595
/// A set of join keys referencing an input.
///
/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
/// avoid changes (and thereby possible regressions) in plans that have equijoin
/// predicates consisting only of column refs.
///
/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
/// able to get rid of this code, but as it stands `Map` simplification seems
/// more cumbersome than `Project` simplification, so do this just to be sure.
enum JoinKeys {
    /// All keys are plain column references, given as column indices of the
    /// input. Restricting an input to these keys only needs a `project`.
    Outputs(Vec<usize>),
    /// At least one key is not a plain column reference, so the keys are kept
    /// as scalar expressions over the input. Restricting an input to these
    /// keys needs a `map` followed by a `project`.
    Scalars(Vec<MirScalarExpr>),
}
2611
2612impl JoinKeys {
2613 fn len(&self) -> usize {
2614 match self {
2615 JoinKeys::Outputs(outputs) => outputs.len(),
2616 JoinKeys::Scalars(scalars) => scalars.len(),
2617 }
2618 }
2619}
2620
/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
/// code.
trait LoweringExt {
    /// Consumes `self` and returns a relation whose columns are exactly the
    /// given `join_keys`, in order.
    ///
    /// See [`MirRelationExpr::restrict`].
    fn restrict(self, join_keys: JoinKeys) -> Self;
}
2627
2628impl LoweringExt for MirRelationExpr {
2629 /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2630 fn restrict(self, join_keys: JoinKeys) -> Self {
2631 let num_keys = join_keys.len();
2632 match join_keys {
2633 JoinKeys::Outputs(outputs) => self.project(outputs),
2634 JoinKeys::Scalars(scalars) => {
2635 let input_arity = self.arity();
2636 let outputs = (input_arity..input_arity + num_keys).collect();
2637 self.map(scalars).project(outputs)
2638 }
2639 }
2640 }
2641}