mz_sql/plan/lowering/
variadic_left.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use itertools::Itertools;
11use mz_expr::{MirRelationExpr, MirScalarExpr};
12use mz_ore::soft_assert_eq_or_log;
13use mz_repr::Diff;
14
15use crate::plan::PlanError;
16use crate::plan::hir::{HirRelationExpr, HirScalarExpr};
17use crate::plan::lowering::{ColumnMap, Context, CteMap};
18
19/// Attempt to render a stack of left joins as an inner join against "enriched" right relations.
20///
21/// This optimization applies for a contiguous block of left joins where the `right` term is not
22/// correlated, and where the `on` constraints equate columns in `right` to expressions over some
23/// single prior joined relation (`left`, or a prior `right`).
24///
25/// The plan is to enrich each `right` with any missing key values, extracted by applying the equated
26/// expressions to the source collection and then introducing them to an "augmented" right relation.
27/// The introduced records are augmented with null values where missing, and an additional column that
28/// indicates whether the data are original or augmented (important for masking out introduced keys).
29///
30/// Importantly, we need to introduce the constraints that equate columns and expressions in the `Join`,
31/// as a `Filter` will still use SQL's equality, which treats NULL as unequal (we want them to match).
32/// We could replace each `(col = expr)` with `(col = expr OR (col IS NULL AND expr IS NULL))`.
33pub(crate) fn attempt_left_join_magic(
34    left: &HirRelationExpr,
35    rights: Vec<(&HirRelationExpr, &HirScalarExpr)>,
36    id_gen: &mut mz_ore::id_gen::IdGen,
37    get_outer: MirRelationExpr,
38    col_map: &ColumnMap,
39    cte_map: &mut CteMap,
40    context: &Context,
41) -> Result<Option<MirRelationExpr>, PlanError> {
42    use mz_expr::LocalId;
43
44    let inc_metrics = |case: &str| {
45        if let Some(metrics) = context.metrics {
46            metrics.inc_outer_join_lowering(case);
47        }
48    };
49
50    let oa = get_outer.arity();
51    tracing::debug!(
52        inputs = rights.len() + 1,
53        outer_arity = oa,
54        "attempt_left_join_magic"
55    );
56
57    if oa > 0 {
58        // Bail out in correlated contexts for now. Even though the code below
59        // supports them, we want to test this code path more thoroughly before
60        // enabling this.
61        tracing::debug!(case = 1, oa, "attempt_left_join_magic");
62        inc_metrics("voj_1");
63        return Ok(None);
64    }
65
66    // Will contain a list of let binding obligations.
67    // We may modify the values if we find promising prior values.
68    let mut bindings = Vec::new();
69    let mut augmented = Vec::new();
70    // A vector associating result columns with their corresponding input number
71    // (where 0 indicates columns from the outer context).
72    let mut bound_to = (0..oa).map(|_| 0).collect::<Vec<_>>();
73    // A vector associating inputs with their arities (where the [0] entry
74    // corresponds to the arity of the outer context).
75    let mut arities = vec![oa];
76
77    // Left relation, its type, and its arity.
78    let left = left
79        .clone()
80        .applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
81    let full_left_typ = left.typ();
82    let lt = full_left_typ
83        .column_types
84        .iter()
85        .skip(oa)
86        .cloned()
87        .collect_vec();
88    let la = lt.len();
89
90    // Create a new let binding to use as input.
91    // We may use these relations multiple times to extract augmenting values.
92    let id = LocalId::new(id_gen.allocate_id());
93    // The join body that we will iteratively develop.
94    let mut body = MirRelationExpr::local_get(id, full_left_typ);
95    bindings.push((id, body.clone(), left));
96    bound_to.extend((0..la).map(|_| 1));
97    arities.push(la);
98
99    // "body arity": number of columns in `body`; the join we are building.
100    let mut ba = la;
101
102    // For each LEFT JOIN, there is a `right` input and an `on` constraint.
103    // We want to decorrelate them, failing if there are subqueries because omg no,
104    // and then check to see if the decorrelated `on` equates RHS columns with values
105    // in one prior input. If so; bring those values into the mix, and bind that as
106    // the value of the `Let` binding.
107    for (index, (right, on)) in rights.into_iter().rev().enumerate() {
108        // Correlated right expressions are handled in a different branch than standard
109        // outer join lowering, and I don't know what they mean. Fail conservatively.
110        if right.is_correlated() {
111            tracing::debug!(case = 2, index, "attempt_left_join_magic");
112            inc_metrics("voj_2");
113            return Ok(None);
114        }
115
116        // Decorrelate `right`.
117        let right_col_map = col_map.enter_scope(0);
118        let right = right
119            .clone()
120            .map(vec![HirScalarExpr::literal_true()]) // add a bit to mark "real" rows.
121            .applied_to(id_gen, get_outer.clone(), &right_col_map, cte_map, context)?;
122        let full_right_typ = right.typ();
123        let rt = full_right_typ
124            .column_types
125            .iter()
126            .skip(oa)
127            .cloned()
128            .collect_vec();
129        let ra = rt.len() - 1; // don't count the new column
130
131        let mut right_type = full_right_typ;
132        // Create a binding for `right`, unadulterated.
133        let id = LocalId::new(id_gen.allocate_id());
134        let get_right = MirRelationExpr::local_get(id, right_type.clone());
135        // Create a binding for the augmented right, which we will form here but use before we do.
136        // We want the join to be based off of the augmented relation, but we don't yet know how
137        // to augment it until we decorrelate `on`. So, we use a `Get` binding that we backfill.
138        for column in right_type.column_types.iter_mut() {
139            column.nullable = true;
140        }
141        right_type.keys.clear();
142        let aug_id = LocalId::new(id_gen.allocate_id());
143        let aug_right = MirRelationExpr::local_get(aug_id, right_type);
144
145        bindings.push((id, get_right.clone(), right));
146        bound_to.extend((0..ra).map(|_| 2 + index));
147        arities.push(ra);
148
149        // Cartesian join but equating the outer columns.
150        let mut product = MirRelationExpr::join(
151            vec![body, aug_right.clone()],
152            (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
153        )
154        // ... remove the second copy of the outer columns.
155        .project(
156            (0..(oa + ba))
157                .chain((oa + ba + oa)..(oa + ba + oa + ra + 1)) // include new column
158                .collect(),
159        );
160
161        // Decorrelate and lower the `on` clause.
162        let on = on
163            .clone()
164            .applied_to(id_gen, col_map, cte_map, &mut product, &None, context)?;
165
166        // if `on` added any new columns, .. no clue what to do.
167        // Return with failure, to avoid any confusion.
168        if product.arity() > oa + ba + ra + 1 {
169            tracing::debug!(case = 3, index, "attempt_left_join_magic");
170            inc_metrics("voj_3");
171            return Ok(None);
172        }
173
174        // If `on` equates columns in `right` with columns in some input,
175        // not just "any columns in `body`" but some single specific input,
176        // then we can fish out values from that input. If it equates values
177        // across multiple inputs, we would need to fish out valid tuples and
178        // no idea how we would get those w/o doing a join or a cartesian product.
179        let equations = if let Some(list) = decompose_equations(&on) {
180            list
181        } else {
182            tracing::debug!(case = 4, index, "attempt_left_join_magic");
183            inc_metrics("voj_4");
184            return Ok(None);
185        };
186
187        // We now need to see if all left columns exist in some input relation,
188        // and that all right columns are actually in the right relation. Idk.
189        // Left columns less than `oa` do not bind to an input, as they are for
190        // columns present in all inputs.
191        let mut bound_input = None;
192        for (left, right) in equations.iter().cloned() {
193            // If the right reference is not actually to `right`, bail out.
194            if right < oa + ba {
195                tracing::debug!(case = 5, index, "attempt_left_join_magic");
196                inc_metrics("voj_5");
197                return Ok(None);
198            }
199            // Only columns not from the outer scope introduce bindings.
200            if left >= oa {
201                if let Some(bound) = bound_input {
202                    // If left references come from different inputs, bail out.
203                    if bound_to[left] != bound {
204                        tracing::debug!(case = 6, index, "attempt_left_join_magic");
205                        inc_metrics("voj_6");
206                        return Ok(None);
207                    }
208                }
209                bound_input = Some(bound_to[left]);
210            }
211        }
212
213        if let Some(bound) = bound_input {
214            // This is great news; we have an input `bound` that we can augment,
215            // and just need to pull those values in to the definition of `right`.
216
217            // Add up prior arities, to learn what to subtract from left references.
218            // Don't subtract anything from left references less than `oa`!
219            let offset: usize = arities[0..bound].iter().sum();
220
221            // We now want to grab the `Get` for both left and right relations,
222            // which we will project to get distinct values, then difference and
223            // threshold to find those present in left but missing in right.
224            let get_left = &bindings[bound - 1].1;
225            // Set up a type for the all-nulls row we need to introduce.
226            let mut left_typ = get_left.typ();
227            for col in left_typ.column_types.iter_mut() {
228                col.nullable = true;
229            }
230            left_typ.keys.clear();
231            // `get_right` is already bound.
232
233            // Augment left_vals an all `Null` row, so that any null values
234            // match with nulls, and compute the distinct join keys in the
235            // resulting union.
236            let left_vals = MirRelationExpr::union(
237                get_left.clone(),
238                MirRelationExpr::Constant {
239                    rows: Ok(vec![(
240                        mz_repr::Row::pack(
241                            std::iter::repeat(mz_repr::Datum::Null).take(left_typ.arity()),
242                        ),
243                        Diff::ONE,
244                    )]),
245                    typ: left_typ,
246                },
247            )
248            .project(
249                equations
250                    .iter()
251                    .map(|(l, _)| if l < &oa { *l } else { l - offset })
252                    .collect::<Vec<_>>(),
253            )
254            .distinct();
255
256            // Compute the non-Null join keys on the right side. We skip the
257            // distinct because the eventual `threshold` between `left_vals` and
258            // `right_vals` protects us.
259            let right_vals = get_right
260                .clone()
261                // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
262                // ensures that we won't remove the all `Null` row in the
263                // eventual `threshold` call.
264                .filter(
265                    equations
266                        .iter()
267                        .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
268                )
269                // Retain only the keys referenced on the right side of the LEFT
270                // JOIN equations.
271                .project(
272                    equations
273                        .iter()
274                        .map(|(_, r)| r - oa - ba)
275                        .collect::<Vec<_>>(),
276                );
277
278            // Now we need to permute them into place, and leave `Datum::Null` values behind.
279            let additions = MirRelationExpr::union(right_vals.negate(), left_vals)
280                .threshold()
281                .map(
282                    // Append nulls for all get_right columns, including the
283                    // extra column at the end that is used to differentiate between
284                    // augmented and original columns in the aug_value.
285                    rt.iter()
286                        .map(|t| MirScalarExpr::literal_null(t.scalar_type.clone()))
287                        .collect::<Vec<_>>(),
288                )
289                .project({
290                    // By default, we'll place post-pended nulls in each location.
291                    // We will overwrite this with instructions to find augmenting values.
292
293                    // Start with a projection that retains the last |rt|
294                    // columns corresponding to the NULLs from the above
295                    // .map(...) call.
296                    let mut projection =
297                        (equations.len()..equations.len() + rt.len()).collect::<Vec<_>>();
298                    // Replace NULLs columns corresponding to rhs columns
299                    // referenced in an ON equation with the actual rhs value
300                    // (located at `index`).
301                    for (index, (_, right)) in equations.iter().enumerate() {
302                        projection[*right - oa - ba] = index;
303                    }
304
305                    projection
306                });
307
308            // This is where we should add a boolean column to indicate that the row is augmented,
309            // so that after the join is done we can overwrite all values for `right` with null values.
310            // This is a quirk of how outer joins work: the matched columns are left as null.
311
312            // TODO(aalexandrov): if we never see an error from this we can
313            // 1. Use `get_right` instead of `bindings[index + 1].1.clone()`.
314            // 2. Simplify bindings to use tuples instead of triples.
315            soft_assert_eq_or_log!(&bindings[index + 1].1, &get_right);
316
317            let aug_value = MirRelationExpr::union(
318                bindings[index + 1]
319                    .1
320                    .clone()
321                    // The #c1 IS NOT NULL AND ... AND #cn IS NOT NULL filter
322                    // ensures that the `Null` keys appearing on the left side
323                    // can only match the all `Null` row from additions in the
324                    // eventual `product.filter(...)` call.
325                    .filter(
326                        equations
327                            .iter()
328                            .map(|(_, r)| MirScalarExpr::column(r - oa - ba).call_is_null().not()),
329                    ),
330                additions,
331            );
332
333            // Record the binding we'll need to make for `aug_id`.
334            augmented.push((aug_id, aug_right, aug_value));
335
336            // Update `body` to reflect the product, filtered by `on`.
337            body = product.filter(recompose_equations(equations));
338
339            body = body
340                // Update `body` so that each new column consults its final
341                // column, and if null sets all right columns to null.
342                .map(
343                    (oa + ba..oa + ba + ra)
344                        .map(|col| MirScalarExpr::If {
345                            cond: Box::new(MirScalarExpr::Column(oa + ba + ra).call_is_null()),
346                            then: Box::new(MirScalarExpr::literal_null(
347                                rt[col - (oa + ba)].scalar_type.clone(),
348                            )),
349                            els: Box::new(MirScalarExpr::Column(col)),
350                        })
351                        .collect(),
352                )
353                // Replace the original |ra + 1| columns with the |ra| columns
354                // produced by the above map(...) call.
355                .project(
356                    (0..oa + ba)
357                        .chain(oa + ba + ra + 1..oa + ba + ra + 1 + ra)
358                        .collect(),
359                );
360
361            ba += ra;
362
363            assert_eq!(oa + ba, body.arity());
364        } else {
365            tracing::debug!(case = 7, index, "attempt_left_join_magic");
366            inc_metrics("voj_7");
367            return Ok(None);
368        }
369    }
370
371    // If we've gotten this for, we've populated `bindings` with various let bindings
372    // we must now create, all wrapped around `body`.
373    while let Some((id, _get, value)) = augmented.pop() {
374        body = MirRelationExpr::Let {
375            id,
376            value: Box::new(value),
377            body: Box::new(body),
378        };
379    }
380    while let Some((id, _get, value)) = bindings.pop() {
381        body = MirRelationExpr::Let {
382            id,
383            value: Box::new(value),
384            body: Box::new(body),
385        };
386    }
387
388    tracing::debug!(case = 0, "attempt_left_join_magic");
389    inc_metrics("voj_0");
390    Ok(Some(body))
391}
392
393use mz_expr::{BinaryFunc, VariadicFunc};
394
395/// If `predicate` can be decomposed as any number of `col(x) = col(y)` expressions anded together, return them.
396fn decompose_equations(predicate: &MirScalarExpr) -> Option<Vec<(usize, usize)>> {
397    let mut equations = Vec::new();
398
399    let mut todo = vec![predicate];
400    while let Some(expr) = todo.pop() {
401        match expr {
402            MirScalarExpr::CallVariadic {
403                func: VariadicFunc::And,
404                exprs,
405            } => {
406                todo.extend(exprs.iter());
407            }
408            MirScalarExpr::CallBinary {
409                func: BinaryFunc::Eq,
410                expr1,
411                expr2,
412            } => {
413                if let (MirScalarExpr::Column(c1), MirScalarExpr::Column(c2)) = (&**expr1, &**expr2)
414                {
415                    if c1 < c2 {
416                        equations.push((*c1, *c2));
417                    } else {
418                        equations.push((*c2, *c1));
419                    }
420                } else {
421                    return None;
422                }
423            }
424            e if e.is_literal_true() => (), // `USING(c1,...,cN)` translates to `true && c1 = c1 ... cN = cN`.
425            _ => return None,
426        }
427    }
428
429    // Remove duplicates
430    equations.sort();
431    equations.dedup();
432
433    // Ensure that every rhs column c2 appears only once. Otherwise, we have at
434    // least two lhs columns c1 and c1' that are rendered equal by the same c2
435    // column. The VOJ lowering will then produce a plan that will incorrectly
436    // push down a local filter c1 = c1' to the lhs (see database-issues#7892).
437    if equations.iter().duplicates_by(|(_, c)| c).next().is_some() {
438        return None;
439    }
440
441    Some(equations)
442}
443
444/// Turns column equation into idiomatic Rust equation, where nulls equate.
445fn recompose_equations(pairs: Vec<(usize, usize)>) -> Vec<MirScalarExpr> {
446    pairs
447        .iter()
448        .map(|(x, y)| MirScalarExpr::CallVariadic {
449            func: VariadicFunc::Or,
450            exprs: vec![
451                MirScalarExpr::CallBinary {
452                    func: BinaryFunc::Eq,
453                    expr1: Box::new(MirScalarExpr::Column(*x)),
454                    expr2: Box::new(MirScalarExpr::Column(*y)),
455                },
456                MirScalarExpr::CallVariadic {
457                    func: VariadicFunc::And,
458                    exprs: vec![
459                        MirScalarExpr::Column(*x).call_is_null(),
460                        MirScalarExpr::Column(*y).call_is_null(),
461                    ],
462                },
463            ],
464        })
465        .collect()
466}