// Source file: mz_sql/plan/lowering/variadic_left.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use itertools::Itertools;
11use mz_expr::{MirRelationExpr, MirScalarExpr, func};
12use mz_ore::soft_assert_eq_or_log;
13use mz_repr::Diff;
14
15use crate::plan::PlanError;
16use crate::plan::hir::{HirRelationExpr, HirScalarExpr};
17use crate::plan::lowering::{ColumnMap, Context, CteMap};
18
19/// Attempt to render a stack of left joins as an inner join against "enriched" right relations.
20///
21/// This optimization applies for a contiguous block of left joins where the `right` term is not
22/// correlated, and where the `on` constraints equate columns in `right` to expressions over some
23/// single prior joined relation (`left`, or a prior `right`).
24///
25/// The plan is to enrich each `right` with any missing key values, extracted by applying the equated
26/// expressions to the source collection and then introducing them to an "augmented" right relation.
27/// The introduced records are augmented with null values where missing, and an additional column that
28/// indicates whether the data are original or augmented (important for masking out introduced keys).
29///
30/// Importantly, we need to introduce the constraints that equate columns and expressions in the `Join`,
31/// as a `Filter` will still use SQL's equality, which treats NULL as unequal (we want them to match).
32/// We could replace each `(col = expr)` with `(col = expr OR (col IS NULL AND expr IS NULL))`.
33pub(crate) fn attempt_left_join_magic(
34    left: &HirRelationExpr,
35    rights: Vec<(&HirRelationExpr, &HirScalarExpr)>,
36    id_gen: &mut mz_ore::id_gen::IdGen,
37    get_outer: MirRelationExpr,
38    col_map: &ColumnMap,
39    cte_map: &mut CteMap,
40    context: &Context,
41) -> Result<Option<MirRelationExpr>, PlanError> {
42    use mz_expr::LocalId;
43
44    let inc_metrics = |case: &str| {
45        if let Some(metrics) = context.metrics {
46            metrics.inc_outer_join_lowering(case);
47        }
48    };
49
50    let oa = get_outer.arity();
51    tracing::debug!(
52        inputs = rights.len() + 1,
53        outer_arity = oa,
54        "attempt_left_join_magic"
55    );
56
57    if oa > 0 {
58        // Bail out in correlated contexts for now. Even though the code below
59        // supports them, we want to test this code path more thoroughly before
60        // enabling this.
61        tracing::debug!(case = 1, oa, "attempt_left_join_magic");
62        inc_metrics("voj_1");
63        return Ok(None);
64    }
65
66    // Will contain a list of let binding obligations.
67    // We may modify the values if we find promising prior values.
68    let mut bindings = Vec::new();
69    let mut augmented = Vec::new();
70    // A vector associating result columns with their corresponding input number
71    // (where 0 indicates columns from the outer context).
72    let mut bound_to = (0..oa).map(|_| 0).collect::<Vec<_>>();
73    // A vector associating inputs with their arities (where the [0] entry
74    // corresponds to the arity of the outer context).
75    let mut arities = vec![oa];
76
77    // Left relation, its type, and its arity.
78    let left = left
79        .clone()
80        .applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
81    let full_left_typ = left.typ();
82    let lt = full_left_typ
83        .column_types
84        .iter()
85        .skip(oa)
86        .cloned()
87        .collect_vec();
88    let la = lt.len();
89
90    // Create a new let binding to use as input.
91    // We may use these relations multiple times to extract augmenting values.
92    let id = LocalId::new(id_gen.allocate_id());
93    // The join body that we will iteratively develop.
94    let mut body = MirRelationExpr::local_get(id, full_left_typ.clone());
95    bindings.push((id, body.clone(), left));
96    bound_to.extend((0..la).map(|_| 1));
97    arities.push(la);
98
99    // "body arity": number of columns in `body`; the join we are building.
100    let mut ba = la;
101
102    // For each LEFT JOIN, there is a `right` input and an `on` constraint.
103    // We want to decorrelate them, failing if there are subqueries because omg no,
104    // and then check to see if the decorrelated `on` equates RHS columns with values
105    // in one prior input. If so; bring those values into the mix, and bind that as
106    // the value of the `Let` binding.
107    for (index, (right, on)) in rights.into_iter().rev().enumerate() {
108        // Correlated right expressions are handled in a different branch than standard
109        // outer join lowering, and I don't know what they mean. Fail conservatively.
110        if right.is_correlated() {
111            tracing::debug!(case = 2, index, "attempt_left_join_magic");
112            inc_metrics("voj_2");
113            return Ok(None);
114        }
115
116        // Decorrelate `right`.
117        let right_col_map = col_map.enter_scope(0);
118        let right = right
119            .clone()
120            .map(vec![HirScalarExpr::literal_true()]) // add a bit to mark "real" rows.
121            .applied_to(id_gen, get_outer.clone(), &right_col_map, cte_map, context)?;
122        let full_right_typ = right.typ();
123        let rt = full_right_typ
124            .column_types
125            .iter()
126            .skip(oa)
127            .cloned()
128            .collect_vec();
129        let ra = rt.len() - 1; // don't count the new column
130
131        let mut right_type = full_right_typ;
132        // Create a binding for `right`, unadulterated.
133        let id = LocalId::new(id_gen.allocate_id());
134        let get_right = MirRelationExpr::local_get(id, right_type.clone());
135        // Create a binding for the augmented right, which we will form here but use before we do.
136        // We want the join to be based off of the augmented relation, but we don't yet know how
137        // to augment it until we decorrelate `on`. So, we use a `Get` binding that we backfill.
138        for column in right_type.column_types.iter_mut() {
139            column.nullable = true;
140        }
141        right_type.keys.clear();
142        let aug_id = LocalId::new(id_gen.allocate_id());
143        let aug_right = MirRelationExpr::local_get(aug_id, right_type.clone());
144
145        bindings.push((id, get_right.clone(), right));
146        bound_to.extend((0..ra).map(|_| 2 + index));
147        arities.push(ra);
148
149        // Cartesian join but equating the outer columns.
150        let mut product = MirRelationExpr::join(
151            vec![body, aug_right.clone()],
152            (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
153        )
154        // ... remove the second copy of the outer columns.
155        .project(
156            (0..(oa + ba))
157                .chain((oa + ba + oa)..(oa + ba + oa + ra + 1)) // include new column
158                .collect(),
159        );
160
161        // Decorrelate and lower the `on` clause.
162        let on = on
163            .clone()
164            .applied_to(id_gen, col_map, cte_map, &mut product, &None, context)?;
165
166        // if `on` added any new columns, .. no clue what to do.
167        // Return with failure, to avoid any confusion.
168        if product.arity() > oa + ba + ra + 1 {
169            tracing::debug!(case = 3, index, "attempt_left_join_magic");
170            inc_metrics("voj_3");
171            return Ok(None);
172        }
173
174        // If `on` equates columns in `right` with columns in some input,
175        // not just "any columns in `body`" but some single specific input,
176        // then we can fish out values from that input. If it equates values
177        // across multiple inputs, we would need to fish out valid tuples and
178        // no idea how we would get those w/o doing a join or a cartesian product.
179        let (equations, non_crossing_equations) =
180            if let Some(list) = decompose_left_to_right_equations(&on, oa + ba) {
181                list
182            } else {
183                tracing::debug!(case = 4, index, "attempt_left_join_magic");
184                inc_metrics("voj_4");
185                return Ok(None);
186            };
187
188        if !non_crossing_equations.is_empty() {
189            // TODO(mgree) This case isn't _impossible_, but it's complicated.
190            // We have equations that cross from left to right, but we also have
191            // left-left or right-right equations. Making sure we get exactly the
192            // right results here is hard enough that we don't attempt it.
193            tracing::debug!(case = 8, index, "attempt_left_join_magic");
194            inc_metrics("voj_8");
195            return Ok(None);
196        }
197
198        // We now need to see if all left columns exist in some input relation,
199        // and that all right columns are actually in the right relation. Idk.
200        // Left columns less than `oa` do not bind to an input, as they are for
201        // columns present in all inputs.
202        let mut bound_input = None;
203        for (left, right) in equations.iter().cloned() {
204            // If the right reference is not actually to `right`, bail out.
205            if right < oa + ba {
206                tracing::debug!(case = 5, index, "attempt_left_join_magic");
207                inc_metrics("voj_5");
208                return Ok(None);
209            }
210            // Only columns not from the outer scope introduce bindings (`oa <= left`)
211            // And `left` needs to be a column in the left relation (`left < oa + ba`)
212            if oa <= left && left < oa + ba {
213                if let Some(bound) = bound_input {
214                    // If left references come from different inputs, bail out.
215                    if bound_to[left] != bound {
216                        tracing::debug!(case = 6, index, "attempt_left_join_magic");
217                        inc_metrics("voj_6");
218                        return Ok(None);
219                    }
220                }
221                bound_input = Some(bound_to[left]);
222            }
223        }
224
225        if let Some(bound) = bound_input {
226            // This is great news; we have an input `bound` that we can augment,
227            // and just need to pull those values in to the definition of `right`.
228
229            // Add up prior arities, to learn what to subtract from left references.
230            // Don't subtract anything from left references less than `oa`!
231            let offset: usize = arities[0..bound].iter().sum();
232
233            // We now want to grab the `Get` for both left and right relations,
234            // which we will project to get distinct values, then difference and
235            // threshold to find those present in left but missing in right.
236            let get_left = &bindings[bound - 1].1;
237            // Set up a type for the all-nulls row we need to introduce.
238            let mut left_typ = get_left.typ();
239            for col in left_typ.column_types.iter_mut() {
240                col.nullable = true;
241            }
242            left_typ.keys.clear();
243            // `get_right` is already bound.
244
245            // Augment left_vals an all `Null` row, so that any null values
246            // match with nulls, and compute the distinct join keys in the
247            // resulting union.
248            let left_vals = MirRelationExpr::union(
249                get_left.clone(),
250                MirRelationExpr::Constant {
251                    rows: Ok(vec![(
252                        mz_repr::Row::pack(
253                            std::iter::repeat(mz_repr::Datum::Null).take(left_typ.arity()),
254                        ),
255                        Diff::ONE,
256                    )]),
257                    typ: left_typ.clone(),
258                },
259            )
260            .project(
261                equations
262                    .iter()
263                    .map(|(l, _)| if l < &oa { *l } else { l - offset })
264                    .collect::<Vec<_>>(),
265            )
266            .distinct();
267
268            // Compute the non-Null join keys on the right side. We skip the
269            // distinct because the eventual `threshold` between `left_vals` and
270            // `right_vals` protects us.
271            let right_vals = get_right
272                .clone()
273                // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
274                // ensures that we won't remove the all `Null` row in the
275                // eventual `threshold` call.
276                .filter(
277                    equations
278                        .iter()
279                        .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
280                )
281                // Retain only the keys referenced on the right side of the LEFT
282                // JOIN equations.
283                .project(
284                    equations
285                        .iter()
286                        .map(|(_, r)| r - oa - ba)
287                        .collect::<Vec<_>>(),
288                );
289
290            // Now we need to permute them into place, and leave `Datum::Null` values behind.
291            let additions = MirRelationExpr::union(right_vals.negate(), left_vals)
292                .threshold()
293                .map(
294                    // Append nulls for all get_right columns, including the
295                    // extra column at the end that is used to differentiate between
296                    // augmented and original columns in the aug_value.
297                    rt.iter()
298                        .map(|t| MirScalarExpr::literal_null(t.scalar_type.clone()))
299                        .collect::<Vec<_>>(),
300                )
301                .project({
302                    // By default, we'll place post-pended nulls in each location.
303                    // We will overwrite this with instructions to find augmenting values.
304
305                    // Start with a projection that retains the last |rt|
306                    // columns corresponding to the NULLs from the above
307                    // .map(...) call.
308                    let mut projection =
309                        (equations.len()..equations.len() + rt.len()).collect::<Vec<_>>();
310                    // Replace NULLs columns corresponding to rhs columns
311                    // referenced in an ON equation with the actual rhs value
312                    // (located at `index`).
313                    for (index, (_, right)) in equations.iter().enumerate() {
314                        projection[*right - oa - ba] = index;
315                    }
316
317                    projection
318                });
319
320            // This is where we should add a boolean column to indicate that the row is augmented,
321            // so that after the join is done we can overwrite all values for `right` with null values.
322            // This is a quirk of how outer joins work: the matched columns are left as null.
323
324            // TODO(aalexandrov): if we never see an error from this we can
325            // 1. Use `get_right` instead of `bindings[index + 1].1.clone()`.
326            // 2. Simplify bindings to use tuples instead of triples.
327            soft_assert_eq_or_log!(&bindings[index + 1].1, &get_right);
328
329            let aug_value = MirRelationExpr::union(
330                bindings[index + 1]
331                    .1
332                    .clone()
333                    // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
334                    // ensures that the `Null` keys appearing on the left side
335                    // can only match the all `Null` row from additions in the
336                    // eventual `product.filter(...)` call.
337                    .filter(
338                        equations
339                            .iter()
340                            .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
341                    ),
342                additions,
343            );
344
345            // Record the binding we'll need to make for `aug_id`.
346            augmented.push((aug_id, aug_right, aug_value));
347
348            // Update `body` to reflect the product, filtered by `on`.
349            body = product.filter(recompose_equations(equations));
350
351            body = body
352                // Update `body` so that each new column consults its final
353                // column, and if null sets all right columns to null.
354                .map(
355                    (oa + ba..oa + ba + ra)
356                        .map(|col| MirScalarExpr::If {
357                            cond: Box::new(MirScalarExpr::column(oa + ba + ra).call_is_null()),
358                            then: Box::new(MirScalarExpr::literal_null(
359                                rt[col - (oa + ba)].scalar_type.clone(),
360                            )),
361                            els: Box::new(MirScalarExpr::column(col)),
362                        })
363                        .collect(),
364                )
365                // Replace the original |ra + 1| columns with the |ra| columns
366                // produced by the above map(...) call.
367                .project(
368                    (0..oa + ba)
369                        .chain(oa + ba + ra + 1..oa + ba + ra + 1 + ra)
370                        .collect(),
371                );
372
373            ba += ra;
374
375            assert_eq!(oa + ba, body.arity());
376        } else {
377            tracing::debug!(case = 7, index, "attempt_left_join_magic");
378            inc_metrics("voj_7");
379            return Ok(None);
380        }
381    }
382
383    // If we've gotten this for, we've populated `bindings` with various let bindings
384    // we must now create, all wrapped around `body`.
385    while let Some((id, _get, value)) = augmented.pop() {
386        body = MirRelationExpr::Let {
387            id,
388            value: Box::new(value),
389            body: Box::new(body),
390        };
391    }
392    while let Some((id, _get, value)) = bindings.pop() {
393        body = MirRelationExpr::Let {
394            id,
395            value: Box::new(value),
396            body: Box::new(body),
397        };
398    }
399
400    tracing::debug!(case = 0, "attempt_left_join_magic");
401    inc_metrics("voj_0");
402    Ok(Some(body))
403}
404
405use mz_expr::func::variadic::{And, Or};
406use mz_expr::{BinaryFunc, VariadicFunc};
407
408/// If `predicate` can be decomposed as any number of `col(x) = col(y)` expressions anded together, return them.
409/// In order to only find _useful_ equations, one column must be `< lhs_cutoff` and one must be `>= lhs_cutoff`.
410fn decompose_left_to_right_equations(
411    predicate: &MirScalarExpr,
412    lhs_cutoff: usize,
413) -> Option<(Vec<(usize, usize)>, Vec<(usize, usize)>)> {
414    let mut crossing_equations = Vec::new();
415    let mut non_crossing_equations = Vec::new();
416
417    let mut push_equation = |c1: usize, c2: usize| {
418        let l = usize::min(c1, c2);
419        let r = usize::max(c1, c2);
420
421        if l < lhs_cutoff && lhs_cutoff <= r {
422            crossing_equations.push((l, r))
423        } else {
424            non_crossing_equations.push((l, r))
425        }
426    };
427
428    let mut todo = vec![predicate];
429    while let Some(expr) = todo.pop() {
430        match expr {
431            MirScalarExpr::CallVariadic {
432                func: VariadicFunc::And(_),
433                exprs,
434            } => {
435                todo.extend(exprs.iter());
436            }
437            MirScalarExpr::CallBinary {
438                func: BinaryFunc::Eq(_),
439                expr1,
440                expr2,
441            } => {
442                if let (MirScalarExpr::Column(c1, _name1), MirScalarExpr::Column(c2, _name2)) =
443                    (&**expr1, &**expr2)
444                {
445                    push_equation(*c1, *c2);
446                } else {
447                    return None;
448                }
449            }
450            e if e.is_literal_true() => (), // `USING(c1,...,cN)` translates to `true && c1 = c1 ... cN = cN`.
451            _ => return None,
452        }
453    }
454
455    // Remove duplicates
456    crossing_equations.sort();
457    crossing_equations.dedup();
458    non_crossing_equations.sort();
459    non_crossing_equations.dedup();
460
461    // Ensure that every rhs column c2 appears only once. Otherwise, we have at
462    // least two lhs columns c1 and c1' that are rendered equal by the same c2
463    // column. The VOJ lowering will then produce a plan that will incorrectly
464    // push down a local filter c1 = c1' to the lhs (see database-issues#7892).
465    if crossing_equations
466        .iter()
467        .duplicates_by(|(_, c)| c)
468        .next()
469        .is_some()
470    {
471        return None;
472    }
473
474    Some((crossing_equations, non_crossing_equations))
475}
476
477/// Turns column equation into idiomatic Rust equation, where nulls equate.
478fn recompose_equations(pairs: Vec<(usize, usize)>) -> Vec<MirScalarExpr> {
479    pairs
480        .iter()
481        .map(|(x, y)| {
482            MirScalarExpr::call_variadic(
483                Or,
484                vec![
485                    MirScalarExpr::column(*x).call_binary(MirScalarExpr::column(*y), func::Eq),
486                    MirScalarExpr::call_variadic(
487                        And,
488                        vec![
489                            MirScalarExpr::column(*x).call_is_null(),
490                            MirScalarExpr::column(*y).call_is_null(),
491                        ],
492                    ),
493                ],
494            )
495        })
496        .collect()
497}