mz_sql/plan/lowering/variadic_left.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use itertools::Itertools;
11use mz_expr::{MirRelationExpr, MirScalarExpr};
12use mz_ore::soft_assert_eq_or_log;
13use mz_repr::Diff;
14
15use crate::plan::PlanError;
16use crate::plan::hir::{HirRelationExpr, HirScalarExpr};
17use crate::plan::lowering::{ColumnMap, Context, CteMap};
18
19/// Attempt to render a stack of left joins as an inner join against "enriched" right relations.
20///
21/// This optimization applies for a contiguous block of left joins where the `right` term is not
22/// correlated, and where the `on` constraints equate columns in `right` to expressions over some
23/// single prior joined relation (`left`, or a prior `right`).
24///
25/// The plan is to enrich each `right` with any missing key values, extracted by applying the equated
26/// expressions to the source collection and then introducing them to an "augmented" right relation.
27/// The introduced records are augmented with null values where missing, and an additional column that
28/// indicates whether the data are original or augmented (important for masking out introduced keys).
29///
30/// Importantly, we need to introduce the constraints that equate columns and expressions in the `Join`,
31/// as a `Filter` will still use SQL's equality, which treats NULL as unequal (we want them to match).
32/// We could replace each `(col = expr)` with `(col = expr OR (col IS NULL AND expr IS NULL))`.
33pub(crate) fn attempt_left_join_magic(
34 left: &HirRelationExpr,
35 rights: Vec<(&HirRelationExpr, &HirScalarExpr)>,
36 id_gen: &mut mz_ore::id_gen::IdGen,
37 get_outer: MirRelationExpr,
38 col_map: &ColumnMap,
39 cte_map: &mut CteMap,
40 context: &Context,
41) -> Result<Option<MirRelationExpr>, PlanError> {
42 use mz_expr::LocalId;
43
44 let inc_metrics = |case: &str| {
45 if let Some(metrics) = context.metrics {
46 metrics.inc_outer_join_lowering(case);
47 }
48 };
49
50 let oa = get_outer.arity();
51 tracing::debug!(
52 inputs = rights.len() + 1,
53 outer_arity = oa,
54 "attempt_left_join_magic"
55 );
56
57 if oa > 0 {
58 // Bail out in correlated contexts for now. Even though the code below
59 // supports them, we want to test this code path more thoroughly before
60 // enabling this.
61 tracing::debug!(case = 1, oa, "attempt_left_join_magic");
62 inc_metrics("voj_1");
63 return Ok(None);
64 }
65
66 // Will contain a list of let binding obligations.
67 // We may modify the values if we find promising prior values.
68 let mut bindings = Vec::new();
69 let mut augmented = Vec::new();
70 // A vector associating result columns with their corresponding input number
71 // (where 0 indicates columns from the outer context).
72 let mut bound_to = (0..oa).map(|_| 0).collect::<Vec<_>>();
73 // A vector associating inputs with their arities (where the [0] entry
74 // corresponds to the arity of the outer context).
75 let mut arities = vec![oa];
76
77 // Left relation, its type, and its arity.
78 let left = left
79 .clone()
80 .applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
81 let full_left_typ = left.typ();
82 let lt = full_left_typ
83 .column_types
84 .iter()
85 .skip(oa)
86 .cloned()
87 .collect_vec();
88 let la = lt.len();
89
90 // Create a new let binding to use as input.
91 // We may use these relations multiple times to extract augmenting values.
92 let id = LocalId::new(id_gen.allocate_id());
93 // The join body that we will iteratively develop.
94 let mut body = MirRelationExpr::local_get(id, full_left_typ);
95 bindings.push((id, body.clone(), left));
96 bound_to.extend((0..la).map(|_| 1));
97 arities.push(la);
98
99 // "body arity": number of columns in `body`; the join we are building.
100 let mut ba = la;
101
102 // For each LEFT JOIN, there is a `right` input and an `on` constraint.
103 // We want to decorrelate them, failing if there are subqueries because omg no,
104 // and then check to see if the decorrelated `on` equates RHS columns with values
105 // in one prior input. If so; bring those values into the mix, and bind that as
106 // the value of the `Let` binding.
107 for (index, (right, on)) in rights.into_iter().rev().enumerate() {
108 // Correlated right expressions are handled in a different branch than standard
109 // outer join lowering, and I don't know what they mean. Fail conservatively.
110 if right.is_correlated() {
111 tracing::debug!(case = 2, index, "attempt_left_join_magic");
112 inc_metrics("voj_2");
113 return Ok(None);
114 }
115
116 // Decorrelate `right`.
117 let right_col_map = col_map.enter_scope(0);
118 let right = right
119 .clone()
120 .map(vec![HirScalarExpr::literal_true()]) // add a bit to mark "real" rows.
121 .applied_to(id_gen, get_outer.clone(), &right_col_map, cte_map, context)?;
122 let full_right_typ = right.typ();
123 let rt = full_right_typ
124 .column_types
125 .iter()
126 .skip(oa)
127 .cloned()
128 .collect_vec();
129 let ra = rt.len() - 1; // don't count the new column
130
131 let mut right_type = full_right_typ;
132 // Create a binding for `right`, unadulterated.
133 let id = LocalId::new(id_gen.allocate_id());
134 let get_right = MirRelationExpr::local_get(id, right_type.clone());
135 // Create a binding for the augmented right, which we will form here but use before we do.
136 // We want the join to be based off of the augmented relation, but we don't yet know how
137 // to augment it until we decorrelate `on`. So, we use a `Get` binding that we backfill.
138 for column in right_type.column_types.iter_mut() {
139 column.nullable = true;
140 }
141 right_type.keys.clear();
142 let aug_id = LocalId::new(id_gen.allocate_id());
143 let aug_right = MirRelationExpr::local_get(aug_id, right_type);
144
145 bindings.push((id, get_right.clone(), right));
146 bound_to.extend((0..ra).map(|_| 2 + index));
147 arities.push(ra);
148
149 // Cartesian join but equating the outer columns.
150 let mut product = MirRelationExpr::join(
151 vec![body, aug_right.clone()],
152 (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
153 )
154 // ... remove the second copy of the outer columns.
155 .project(
156 (0..(oa + ba))
157 .chain((oa + ba + oa)..(oa + ba + oa + ra + 1)) // include new column
158 .collect(),
159 );
160
161 // Decorrelate and lower the `on` clause.
162 let on = on
163 .clone()
164 .applied_to(id_gen, col_map, cte_map, &mut product, &None, context)?;
165
166 // if `on` added any new columns, .. no clue what to do.
167 // Return with failure, to avoid any confusion.
168 if product.arity() > oa + ba + ra + 1 {
169 tracing::debug!(case = 3, index, "attempt_left_join_magic");
170 inc_metrics("voj_3");
171 return Ok(None);
172 }
173
174 // If `on` equates columns in `right` with columns in some input,
175 // not just "any columns in `body`" but some single specific input,
176 // then we can fish out values from that input. If it equates values
177 // across multiple inputs, we would need to fish out valid tuples and
178 // no idea how we would get those w/o doing a join or a cartesian product.
179 let equations = if let Some(list) = decompose_equations(&on) {
180 list
181 } else {
182 tracing::debug!(case = 4, index, "attempt_left_join_magic");
183 inc_metrics("voj_4");
184 return Ok(None);
185 };
186
187 // We now need to see if all left columns exist in some input relation,
188 // and that all right columns are actually in the right relation. Idk.
189 // Left columns less than `oa` do not bind to an input, as they are for
190 // columns present in all inputs.
191 let mut bound_input = None;
192 for (left, right) in equations.iter().cloned() {
193 // If the right reference is not actually to `right`, bail out.
194 if right < oa + ba {
195 tracing::debug!(case = 5, index, "attempt_left_join_magic");
196 inc_metrics("voj_5");
197 return Ok(None);
198 }
199 // Only columns not from the outer scope introduce bindings.
200 if left >= oa {
201 if let Some(bound) = bound_input {
202 // If left references come from different inputs, bail out.
203 if bound_to[left] != bound {
204 tracing::debug!(case = 6, index, "attempt_left_join_magic");
205 inc_metrics("voj_6");
206 return Ok(None);
207 }
208 }
209 bound_input = Some(bound_to[left]);
210 }
211 }
212
213 if let Some(bound) = bound_input {
214 // This is great news; we have an input `bound` that we can augment,
215 // and just need to pull those values in to the definition of `right`.
216
217 // Add up prior arities, to learn what to subtract from left references.
218 // Don't subtract anything from left references less than `oa`!
219 let offset: usize = arities[0..bound].iter().sum();
220
221 // We now want to grab the `Get` for both left and right relations,
222 // which we will project to get distinct values, then difference and
223 // threshold to find those present in left but missing in right.
224 let get_left = &bindings[bound - 1].1;
225 // Set up a type for the all-nulls row we need to introduce.
226 let mut left_typ = get_left.typ();
227 for col in left_typ.column_types.iter_mut() {
228 col.nullable = true;
229 }
230 left_typ.keys.clear();
231 // `get_right` is already bound.
232
233 // Augment left_vals an all `Null` row, so that any null values
234 // match with nulls, and compute the distinct join keys in the
235 // resulting union.
236 let left_vals = MirRelationExpr::union(
237 get_left.clone(),
238 MirRelationExpr::Constant {
239 rows: Ok(vec![(
240 mz_repr::Row::pack(
241 std::iter::repeat(mz_repr::Datum::Null).take(left_typ.arity()),
242 ),
243 Diff::ONE,
244 )]),
245 typ: left_typ,
246 },
247 )
248 .project(
249 equations
250 .iter()
251 .map(|(l, _)| if l < &oa { *l } else { l - offset })
252 .collect::<Vec<_>>(),
253 )
254 .distinct();
255
256 // Compute the non-Null join keys on the right side. We skip the
257 // distinct because the eventual `threshold` between `left_vals` and
258 // `right_vals` protects us.
259 let right_vals = get_right
260 .clone()
261 // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
262 // ensures that we won't remove the all `Null` row in the
263 // eventual `threshold` call.
264 .filter(
265 equations
266 .iter()
267 .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
268 )
269 // Retain only the keys referenced on the right side of the LEFT
270 // JOIN equations.
271 .project(
272 equations
273 .iter()
274 .map(|(_, r)| r - oa - ba)
275 .collect::<Vec<_>>(),
276 );
277
278 // Now we need to permute them into place, and leave `Datum::Null` values behind.
279 let additions = MirRelationExpr::union(right_vals.negate(), left_vals)
280 .threshold()
281 .map(
282 // Append nulls for all get_right columns, including the
283 // extra column at the end that is used to differentiate between
284 // augmented and original columns in the aug_value.
285 rt.iter()
286 .map(|t| MirScalarExpr::literal_null(t.scalar_type.clone()))
287 .collect::<Vec<_>>(),
288 )
289 .project({
290 // By default, we'll place post-pended nulls in each location.
291 // We will overwrite this with instructions to find augmenting values.
292
293 // Start with a projection that retains the last |rt|
294 // columns corresponding to the NULLs from the above
295 // .map(...) call.
296 let mut projection =
297 (equations.len()..equations.len() + rt.len()).collect::<Vec<_>>();
298 // Replace NULLs columns corresponding to rhs columns
299 // referenced in an ON equation with the actual rhs value
300 // (located at `index`).
301 for (index, (_, right)) in equations.iter().enumerate() {
302 projection[*right - oa - ba] = index;
303 }
304
305 projection
306 });
307
308 // This is where we should add a boolean column to indicate that the row is augmented,
309 // so that after the join is done we can overwrite all values for `right` with null values.
310 // This is a quirk of how outer joins work: the matched columns are left as null.
311
312 // TODO(aalexandrov): if we never see an error from this we can
313 // 1. Use `get_right` instead of `bindings[index + 1].1.clone()`.
314 // 2. Simplify bindings to use tuples instead of triples.
315 soft_assert_eq_or_log!(&bindings[index + 1].1, &get_right);
316
317 let aug_value = MirRelationExpr::union(
318 bindings[index + 1]
319 .1
320 .clone()
321 // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
322 // ensures that the `Null` keys appearing on the left side
323 // can only match the all `Null` row from additions in the
324 // eventual `product.filter(...)` call.
325 .filter(
326 equations
327 .iter()
328 .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
329 ),
330 additions,
331 );
332
333 // Record the binding we'll need to make for `aug_id`.
334 augmented.push((aug_id, aug_right, aug_value));
335
336 // Update `body` to reflect the product, filtered by `on`.
337 body = product.filter(recompose_equations(equations));
338
339 body = body
340 // Update `body` so that each new column consults its final
341 // column, and if null sets all right columns to null.
342 .map(
343 (oa + ba..oa + ba + ra)
344 .map(|col| MirScalarExpr::If {
345 cond: Box::new(MirScalarExpr::Column(oa + ba + ra).call_is_null()),
346 then: Box::new(MirScalarExpr::literal_null(
347 rt[col - (oa + ba)].scalar_type.clone(),
348 )),
349 els: Box::new(MirScalarExpr::Column(col)),
350 })
351 .collect(),
352 )
353 // Replace the original |ra + 1| columns with the |ra| columns
354 // produced by the above map(...) call.
355 .project(
356 (0..oa + ba)
357 .chain(oa + ba + ra + 1..oa + ba + ra + 1 + ra)
358 .collect(),
359 );
360
361 ba += ra;
362
363 assert_eq!(oa + ba, body.arity());
364 } else {
365 tracing::debug!(case = 7, index, "attempt_left_join_magic");
366 inc_metrics("voj_7");
367 return Ok(None);
368 }
369 }
370
371 // If we've gotten this for, we've populated `bindings` with various let bindings
372 // we must now create, all wrapped around `body`.
373 while let Some((id, _get, value)) = augmented.pop() {
374 body = MirRelationExpr::Let {
375 id,
376 value: Box::new(value),
377 body: Box::new(body),
378 };
379 }
380 while let Some((id, _get, value)) = bindings.pop() {
381 body = MirRelationExpr::Let {
382 id,
383 value: Box::new(value),
384 body: Box::new(body),
385 };
386 }
387
388 tracing::debug!(case = 0, "attempt_left_join_magic");
389 inc_metrics("voj_0");
390 Ok(Some(body))
391}
392
393use mz_expr::{BinaryFunc, VariadicFunc};
394
395/// If `predicate` can be decomposed as any number of `col(x) = col(y)` expressions anded together, return them.
396fn decompose_equations(predicate: &MirScalarExpr) -> Option<Vec<(usize, usize)>> {
397 let mut equations = Vec::new();
398
399 let mut todo = vec![predicate];
400 while let Some(expr) = todo.pop() {
401 match expr {
402 MirScalarExpr::CallVariadic {
403 func: VariadicFunc::And,
404 exprs,
405 } => {
406 todo.extend(exprs.iter());
407 }
408 MirScalarExpr::CallBinary {
409 func: BinaryFunc::Eq,
410 expr1,
411 expr2,
412 } => {
413 if let (MirScalarExpr::Column(c1), MirScalarExpr::Column(c2)) = (&**expr1, &**expr2)
414 {
415 if c1 < c2 {
416 equations.push((*c1, *c2));
417 } else {
418 equations.push((*c2, *c1));
419 }
420 } else {
421 return None;
422 }
423 }
424 e if e.is_literal_true() => (), // `USING(c1,...,cN)` translates to `true && c1 = c1 ... cN = cN`.
425 _ => return None,
426 }
427 }
428
429 // Remove duplicates
430 equations.sort();
431 equations.dedup();
432
433 // Ensure that every rhs column c2 appears only once. Otherwise, we have at
434 // least two lhs columns c1 and c1' that are rendered equal by the same c2
435 // column. The VOJ lowering will then produce a plan that will incorrectly
436 // push down a local filter c1 = c1' to the lhs (see database-issues#7892).
437 if equations.iter().duplicates_by(|(_, c)| c).next().is_some() {
438 return None;
439 }
440
441 Some(equations)
442}
443
444/// Turns column equation into idiomatic Rust equation, where nulls equate.
445fn recompose_equations(pairs: Vec<(usize, usize)>) -> Vec<MirScalarExpr> {
446 pairs
447 .iter()
448 .map(|(x, y)| MirScalarExpr::CallVariadic {
449 func: VariadicFunc::Or,
450 exprs: vec![
451 MirScalarExpr::CallBinary {
452 func: BinaryFunc::Eq,
453 expr1: Box::new(MirScalarExpr::Column(*x)),
454 expr2: Box::new(MirScalarExpr::Column(*y)),
455 },
456 MirScalarExpr::CallVariadic {
457 func: VariadicFunc::And,
458 exprs: vec![
459 MirScalarExpr::Column(*x).call_is_null(),
460 MirScalarExpr::Column(*y).call_is_null(),
461 ],
462 },
463 ],
464 })
465 .collect()
466}