mz_sql/plan/lowering/variadic_left.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use itertools::Itertools;
11use mz_expr::{MirRelationExpr, MirScalarExpr, func};
12use mz_ore::soft_assert_eq_or_log;
13use mz_repr::Diff;
14
15use crate::plan::PlanError;
16use crate::plan::hir::{HirRelationExpr, HirScalarExpr};
17use crate::plan::lowering::{ColumnMap, Context, CteMap};
18
19/// Attempt to render a stack of left joins as an inner join against "enriched" right relations.
20///
21/// This optimization applies for a contiguous block of left joins where the `right` term is not
22/// correlated, and where the `on` constraints equate columns in `right` to expressions over some
23/// single prior joined relation (`left`, or a prior `right`).
24///
25/// The plan is to enrich each `right` with any missing key values, extracted by applying the equated
26/// expressions to the source collection and then introducing them to an "augmented" right relation.
27/// The introduced records are augmented with null values where missing, and an additional column that
28/// indicates whether the data are original or augmented (important for masking out introduced keys).
29///
30/// Importantly, we need to introduce the constraints that equate columns and expressions in the `Join`,
31/// as a `Filter` will still use SQL's equality, which treats NULL as unequal (we want them to match).
32/// We could replace each `(col = expr)` with `(col = expr OR (col IS NULL AND expr IS NULL))`.
33pub(crate) fn attempt_left_join_magic(
34 left: &HirRelationExpr,
35 rights: Vec<(&HirRelationExpr, &HirScalarExpr)>,
36 id_gen: &mut mz_ore::id_gen::IdGen,
37 get_outer: MirRelationExpr,
38 col_map: &ColumnMap,
39 cte_map: &mut CteMap,
40 context: &Context,
41) -> Result<Option<MirRelationExpr>, PlanError> {
42 use mz_expr::LocalId;
43
44 let inc_metrics = |case: &str| {
45 if let Some(metrics) = context.metrics {
46 metrics.inc_outer_join_lowering(case);
47 }
48 };
49
50 let oa = get_outer.arity();
51 tracing::debug!(
52 inputs = rights.len() + 1,
53 outer_arity = oa,
54 "attempt_left_join_magic"
55 );
56
57 if oa > 0 {
58 // Bail out in correlated contexts for now. Even though the code below
59 // supports them, we want to test this code path more thoroughly before
60 // enabling this.
61 tracing::debug!(case = 1, oa, "attempt_left_join_magic");
62 inc_metrics("voj_1");
63 return Ok(None);
64 }
65
66 // Will contain a list of let binding obligations.
67 // We may modify the values if we find promising prior values.
68 let mut bindings = Vec::new();
69 let mut augmented = Vec::new();
70 // A vector associating result columns with their corresponding input number
71 // (where 0 indicates columns from the outer context).
72 let mut bound_to = (0..oa).map(|_| 0).collect::<Vec<_>>();
73 // A vector associating inputs with their arities (where the [0] entry
74 // corresponds to the arity of the outer context).
75 let mut arities = vec![oa];
76
77 // Left relation, its type, and its arity.
78 let left = left
79 .clone()
80 .applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
81 let full_left_typ = left.typ();
82 let lt = full_left_typ
83 .column_types
84 .iter()
85 .skip(oa)
86 .cloned()
87 .collect_vec();
88 let la = lt.len();
89
90 // Create a new let binding to use as input.
91 // We may use these relations multiple times to extract augmenting values.
92 let id = LocalId::new(id_gen.allocate_id());
93 // The join body that we will iteratively develop.
94 let mut body = MirRelationExpr::local_get(id, full_left_typ.clone());
95 bindings.push((id, body.clone(), left));
96 bound_to.extend((0..la).map(|_| 1));
97 arities.push(la);
98
99 // "body arity": number of columns in `body`; the join we are building.
100 let mut ba = la;
101
102 // For each LEFT JOIN, there is a `right` input and an `on` constraint.
103 // We want to decorrelate them, failing if there are subqueries because omg no,
104 // and then check to see if the decorrelated `on` equates RHS columns with values
105 // in one prior input. If so; bring those values into the mix, and bind that as
106 // the value of the `Let` binding.
107 for (index, (right, on)) in rights.into_iter().rev().enumerate() {
108 // Correlated right expressions are handled in a different branch than standard
109 // outer join lowering, and I don't know what they mean. Fail conservatively.
110 if right.is_correlated() {
111 tracing::debug!(case = 2, index, "attempt_left_join_magic");
112 inc_metrics("voj_2");
113 return Ok(None);
114 }
115
116 // Decorrelate `right`.
117 let right_col_map = col_map.enter_scope(0);
118 let right = right
119 .clone()
120 .map(vec![HirScalarExpr::literal_true()]) // add a bit to mark "real" rows.
121 .applied_to(id_gen, get_outer.clone(), &right_col_map, cte_map, context)?;
122 let full_right_typ = right.typ();
123 let rt = full_right_typ
124 .column_types
125 .iter()
126 .skip(oa)
127 .cloned()
128 .collect_vec();
129 let ra = rt.len() - 1; // don't count the new column
130
131 let mut right_type = full_right_typ;
132 // Create a binding for `right`, unadulterated.
133 let id = LocalId::new(id_gen.allocate_id());
134 let get_right = MirRelationExpr::local_get(id, right_type.clone());
135 // Create a binding for the augmented right, which we will form here but use before we do.
136 // We want the join to be based off of the augmented relation, but we don't yet know how
137 // to augment it until we decorrelate `on`. So, we use a `Get` binding that we backfill.
138 for column in right_type.column_types.iter_mut() {
139 column.nullable = true;
140 }
141 right_type.keys.clear();
142 let aug_id = LocalId::new(id_gen.allocate_id());
143 let aug_right = MirRelationExpr::local_get(aug_id, right_type.clone());
144
145 bindings.push((id, get_right.clone(), right));
146 bound_to.extend((0..ra).map(|_| 2 + index));
147 arities.push(ra);
148
149 // Cartesian join but equating the outer columns.
150 let mut product = MirRelationExpr::join(
151 vec![body, aug_right.clone()],
152 (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
153 )
154 // ... remove the second copy of the outer columns.
155 .project(
156 (0..(oa + ba))
157 .chain((oa + ba + oa)..(oa + ba + oa + ra + 1)) // include new column
158 .collect(),
159 );
160
161 // Decorrelate and lower the `on` clause.
162 let on = on
163 .clone()
164 .applied_to(id_gen, col_map, cte_map, &mut product, &None, context)?;
165
166 // if `on` added any new columns, .. no clue what to do.
167 // Return with failure, to avoid any confusion.
168 if product.arity() > oa + ba + ra + 1 {
169 tracing::debug!(case = 3, index, "attempt_left_join_magic");
170 inc_metrics("voj_3");
171 return Ok(None);
172 }
173
174 // If `on` equates columns in `right` with columns in some input,
175 // not just "any columns in `body`" but some single specific input,
176 // then we can fish out values from that input. If it equates values
177 // across multiple inputs, we would need to fish out valid tuples and
178 // no idea how we would get those w/o doing a join or a cartesian product.
179 let (equations, non_crossing_equations) =
180 if let Some(list) = decompose_left_to_right_equations(&on, oa + ba) {
181 list
182 } else {
183 tracing::debug!(case = 4, index, "attempt_left_join_magic");
184 inc_metrics("voj_4");
185 return Ok(None);
186 };
187
188 if !non_crossing_equations.is_empty() {
189 // TODO(mgree) This case isn't _impossible_, but it's complicated.
190 // We have equations that cross from left to right, but we also have
191 // left-left or right-right equations. Making sure we get exactly the
192 // right results here is hard enough that we don't attempt it.
193 tracing::debug!(case = 8, index, "attempt_left_join_magic");
194 inc_metrics("voj_8");
195 return Ok(None);
196 }
197
198 // We now need to see if all left columns exist in some input relation,
199 // and that all right columns are actually in the right relation. Idk.
200 // Left columns less than `oa` do not bind to an input, as they are for
201 // columns present in all inputs.
202 let mut bound_input = None;
203 for (left, right) in equations.iter().cloned() {
204 // If the right reference is not actually to `right`, bail out.
205 if right < oa + ba {
206 tracing::debug!(case = 5, index, "attempt_left_join_magic");
207 inc_metrics("voj_5");
208 return Ok(None);
209 }
210 // Only columns not from the outer scope introduce bindings (`oa <= left`)
211 // And `left` needs to be a column in the left relation (`left < oa + ba`)
212 if oa <= left && left < oa + ba {
213 if let Some(bound) = bound_input {
214 // If left references come from different inputs, bail out.
215 if bound_to[left] != bound {
216 tracing::debug!(case = 6, index, "attempt_left_join_magic");
217 inc_metrics("voj_6");
218 return Ok(None);
219 }
220 }
221 bound_input = Some(bound_to[left]);
222 }
223 }
224
225 if let Some(bound) = bound_input {
226 // This is great news; we have an input `bound` that we can augment,
227 // and just need to pull those values in to the definition of `right`.
228
229 // Add up prior arities, to learn what to subtract from left references.
230 // Don't subtract anything from left references less than `oa`!
231 let offset: usize = arities[0..bound].iter().sum();
232
233 // We now want to grab the `Get` for both left and right relations,
234 // which we will project to get distinct values, then difference and
235 // threshold to find those present in left but missing in right.
236 let get_left = &bindings[bound - 1].1;
237 // Set up a type for the all-nulls row we need to introduce.
238 let mut left_typ = get_left.typ();
239 for col in left_typ.column_types.iter_mut() {
240 col.nullable = true;
241 }
242 left_typ.keys.clear();
243 // `get_right` is already bound.
244
245 // Augment left_vals an all `Null` row, so that any null values
246 // match with nulls, and compute the distinct join keys in the
247 // resulting union.
248 let left_vals = MirRelationExpr::union(
249 get_left.clone(),
250 MirRelationExpr::Constant {
251 rows: Ok(vec![(
252 mz_repr::Row::pack(
253 std::iter::repeat(mz_repr::Datum::Null).take(left_typ.arity()),
254 ),
255 Diff::ONE,
256 )]),
257 typ: left_typ.clone(),
258 },
259 )
260 .project(
261 equations
262 .iter()
263 .map(|(l, _)| if l < &oa { *l } else { l - offset })
264 .collect::<Vec<_>>(),
265 )
266 .distinct();
267
268 // Compute the non-Null join keys on the right side. We skip the
269 // distinct because the eventual `threshold` between `left_vals` and
270 // `right_vals` protects us.
271 let right_vals = get_right
272 .clone()
273 // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
274 // ensures that we won't remove the all `Null` row in the
275 // eventual `threshold` call.
276 .filter(
277 equations
278 .iter()
279 .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
280 )
281 // Retain only the keys referenced on the right side of the LEFT
282 // JOIN equations.
283 .project(
284 equations
285 .iter()
286 .map(|(_, r)| r - oa - ba)
287 .collect::<Vec<_>>(),
288 );
289
290 // Now we need to permute them into place, and leave `Datum::Null` values behind.
291 let additions = MirRelationExpr::union(right_vals.negate(), left_vals)
292 .threshold()
293 .map(
294 // Append nulls for all get_right columns, including the
295 // extra column at the end that is used to differentiate between
296 // augmented and original columns in the aug_value.
297 rt.iter()
298 .map(|t| MirScalarExpr::literal_null(t.scalar_type.clone()))
299 .collect::<Vec<_>>(),
300 )
301 .project({
302 // By default, we'll place post-pended nulls in each location.
303 // We will overwrite this with instructions to find augmenting values.
304
305 // Start with a projection that retains the last |rt|
306 // columns corresponding to the NULLs from the above
307 // .map(...) call.
308 let mut projection =
309 (equations.len()..equations.len() + rt.len()).collect::<Vec<_>>();
310 // Replace NULLs columns corresponding to rhs columns
311 // referenced in an ON equation with the actual rhs value
312 // (located at `index`).
313 for (index, (_, right)) in equations.iter().enumerate() {
314 projection[*right - oa - ba] = index;
315 }
316
317 projection
318 });
319
320 // This is where we should add a boolean column to indicate that the row is augmented,
321 // so that after the join is done we can overwrite all values for `right` with null values.
322 // This is a quirk of how outer joins work: the matched columns are left as null.
323
324 // TODO(aalexandrov): if we never see an error from this we can
325 // 1. Use `get_right` instead of `bindings[index + 1].1.clone()`.
326 // 2. Simplify bindings to use tuples instead of triples.
327 soft_assert_eq_or_log!(&bindings[index + 1].1, &get_right);
328
329 let aug_value = MirRelationExpr::union(
330 bindings[index + 1]
331 .1
332 .clone()
333 // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
334 // ensures that the `Null` keys appearing on the left side
335 // can only match the all `Null` row from additions in the
336 // eventual `product.filter(...)` call.
337 .filter(
338 equations
339 .iter()
340 .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
341 ),
342 additions,
343 );
344
345 // Record the binding we'll need to make for `aug_id`.
346 augmented.push((aug_id, aug_right, aug_value));
347
348 // Update `body` to reflect the product, filtered by `on`.
349 body = product.filter(recompose_equations(equations));
350
351 body = body
352 // Update `body` so that each new column consults its final
353 // column, and if null sets all right columns to null.
354 .map(
355 (oa + ba..oa + ba + ra)
356 .map(|col| MirScalarExpr::If {
357 cond: Box::new(MirScalarExpr::column(oa + ba + ra).call_is_null()),
358 then: Box::new(MirScalarExpr::literal_null(
359 rt[col - (oa + ba)].scalar_type.clone(),
360 )),
361 els: Box::new(MirScalarExpr::column(col)),
362 })
363 .collect(),
364 )
365 // Replace the original |ra + 1| columns with the |ra| columns
366 // produced by the above map(...) call.
367 .project(
368 (0..oa + ba)
369 .chain(oa + ba + ra + 1..oa + ba + ra + 1 + ra)
370 .collect(),
371 );
372
373 ba += ra;
374
375 assert_eq!(oa + ba, body.arity());
376 } else {
377 tracing::debug!(case = 7, index, "attempt_left_join_magic");
378 inc_metrics("voj_7");
379 return Ok(None);
380 }
381 }
382
383 // If we've gotten this for, we've populated `bindings` with various let bindings
384 // we must now create, all wrapped around `body`.
385 while let Some((id, _get, value)) = augmented.pop() {
386 body = MirRelationExpr::Let {
387 id,
388 value: Box::new(value),
389 body: Box::new(body),
390 };
391 }
392 while let Some((id, _get, value)) = bindings.pop() {
393 body = MirRelationExpr::Let {
394 id,
395 value: Box::new(value),
396 body: Box::new(body),
397 };
398 }
399
400 tracing::debug!(case = 0, "attempt_left_join_magic");
401 inc_metrics("voj_0");
402 Ok(Some(body))
403}
404
405use mz_expr::func::variadic::{And, Or};
406use mz_expr::{BinaryFunc, VariadicFunc};
407
408/// If `predicate` can be decomposed as any number of `col(x) = col(y)` expressions anded together, return them.
409/// In order to only find _useful_ equations, one column must be `< lhs_cutoff` and one must be `>= lhs_cutoff`.
410fn decompose_left_to_right_equations(
411 predicate: &MirScalarExpr,
412 lhs_cutoff: usize,
413) -> Option<(Vec<(usize, usize)>, Vec<(usize, usize)>)> {
414 let mut crossing_equations = Vec::new();
415 let mut non_crossing_equations = Vec::new();
416
417 let mut push_equation = |c1: usize, c2: usize| {
418 let l = usize::min(c1, c2);
419 let r = usize::max(c1, c2);
420
421 if l < lhs_cutoff && lhs_cutoff <= r {
422 crossing_equations.push((l, r))
423 } else {
424 non_crossing_equations.push((l, r))
425 }
426 };
427
428 let mut todo = vec![predicate];
429 while let Some(expr) = todo.pop() {
430 match expr {
431 MirScalarExpr::CallVariadic {
432 func: VariadicFunc::And(_),
433 exprs,
434 } => {
435 todo.extend(exprs.iter());
436 }
437 MirScalarExpr::CallBinary {
438 func: BinaryFunc::Eq(_),
439 expr1,
440 expr2,
441 } => {
442 if let (MirScalarExpr::Column(c1, _name1), MirScalarExpr::Column(c2, _name2)) =
443 (&**expr1, &**expr2)
444 {
445 push_equation(*c1, *c2);
446 } else {
447 return None;
448 }
449 }
450 e if e.is_literal_true() => (), // `USING(c1,...,cN)` translates to `true && c1 = c1 ... cN = cN`.
451 _ => return None,
452 }
453 }
454
455 // Remove duplicates
456 crossing_equations.sort();
457 crossing_equations.dedup();
458 non_crossing_equations.sort();
459 non_crossing_equations.dedup();
460
461 // Ensure that every rhs column c2 appears only once. Otherwise, we have at
462 // least two lhs columns c1 and c1' that are rendered equal by the same c2
463 // column. The VOJ lowering will then produce a plan that will incorrectly
464 // push down a local filter c1 = c1' to the lhs (see database-issues#7892).
465 if crossing_equations
466 .iter()
467 .duplicates_by(|(_, c)| c)
468 .next()
469 .is_some()
470 {
471 return None;
472 }
473
474 Some((crossing_equations, non_crossing_equations))
475}
476
477/// Turns column equation into idiomatic Rust equation, where nulls equate.
478fn recompose_equations(pairs: Vec<(usize, usize)>) -> Vec<MirScalarExpr> {
479 pairs
480 .iter()
481 .map(|(x, y)| {
482 MirScalarExpr::call_variadic(
483 Or,
484 vec![
485 MirScalarExpr::column(*x).call_binary(MirScalarExpr::column(*y), func::Eq),
486 MirScalarExpr::call_variadic(
487 And,
488 vec![
489 MirScalarExpr::column(*x).call_is_null(),
490 MirScalarExpr::column(*y).call_is_null(),
491 ],
492 ),
493 ],
494 )
495 })
496 .collect()
497}