mz_transform/
semijoin_idempotence.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove semijoins that are applied multiple times to no further effect.
11//!
12//! Mechanically, this transform looks for instances of `A join B` and replaces
13//! `B` with a simpler `C`. It does this in the restricted setting that each `join`
14//! would be a "semijoin": a multiplicity preserving restriction.
15//!
16//! The approach we use here is to restrict our attention to cases where
17//!
//! 1. `A` is a potentially filtered instance of some `Get{id}`,
//! 2. `A join B` equates columns of `A` to all columns of `B`,
//! 3. The multiplicity of any record in `B` is at most one,
//! 4. The values in these records are exactly `Get{id} join C`.
22//!
//! We find a candidate `C` by descending `B` looking for another semijoin between
//! `Get{id}` and some other collection `D` on the same columns that `A` means to join with `B`.
//! Should we find one, we allow arbitrary filters of `Get{id}` on the equated columns,
//! and transfer those filters to the columns of `D`, thereby forming `C`.
27
28use itertools::Itertools;
29use mz_repr::RelationType;
30use std::collections::BTreeMap;
31
32use mz_expr::{Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
33use mz_ore::id_gen::IdGen;
34use mz_ore::stack::{CheckedRecursion, RecursionGuard};
35
36use crate::TransformCtx;
37
38/// Remove redundant semijoin operators
39#[derive(Debug)]
40pub struct SemijoinIdempotence {
41    recursion_guard: RecursionGuard,
42}
43
44impl Default for SemijoinIdempotence {
45    fn default() -> SemijoinIdempotence {
46        SemijoinIdempotence {
47            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
48        }
49    }
50}
51
52impl CheckedRecursion for SemijoinIdempotence {
53    fn recursion_guard(&self) -> &RecursionGuard {
54        &self.recursion_guard
55    }
56}
57
58impl crate::Transform for SemijoinIdempotence {
59    fn name(&self) -> &'static str {
60        "SemijoinIdempotence"
61    }
62
63    #[mz_ore::instrument(
64        target = "optimizer",
65        level = "debug",
66        fields(path.segment = "semijoin_idempotence")
67    )]
68    fn actually_perform_transform(
69        &self,
70        relation: &mut MirRelationExpr,
71        _: &mut TransformCtx,
72    ) -> Result<(), crate::TransformError> {
73        // We need to call `renumber_bindings` because we will call
74        // `MirRelationExpr::collect_expirations`, which relies on this invariant.
75        crate::normalize_lets::renumber_bindings(relation, &mut IdGen::default())?;
76
77        let mut let_replacements = BTreeMap::<LocalId, Vec<Replacement>>::new();
78        let mut gets_behind_gets = BTreeMap::<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>::new();
79        self.action(relation, &mut let_replacements, &mut gets_behind_gets)?;
80
81        mz_repr::explain::trace_plan(&*relation);
82        Ok(())
83    }
84}
85
86impl SemijoinIdempotence {
87    /// * `let_replacements` - `Replacement`s offered up by CTEs.
88    /// * `gets_behind_gets` - The result of `as_filtered_get` called on CTEs.
89    fn action(
90        &self,
91        expr: &mut MirRelationExpr,
92        let_replacements: &mut BTreeMap<LocalId, Vec<Replacement>>,
93        gets_behind_gets: &mut BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
94    ) -> Result<(), crate::TransformError> {
95        // At each node, either gather info about Let bindings or attempt to simplify a join.
96        self.checked_recur(move |_| {
97            match expr {
98                MirRelationExpr::Let { id, value, body } => {
99                    let_replacements.insert(
100                        *id,
101                        list_replacements(&*value, let_replacements, gets_behind_gets),
102                    );
103                    gets_behind_gets.insert(*id, as_filtered_get(value, gets_behind_gets));
104                    self.action(value, let_replacements, gets_behind_gets)?;
105                    self.action(body, let_replacements, gets_behind_gets)?;
106                    // No need to do expirations here, as there is only one CTE (and it can't be
107                    // recursive).
108                }
109                MirRelationExpr::LetRec {
110                    ids,
111                    values,
112                    limits: _,
113                    body,
114                } => {
115                    // Expirations. See comments on `collect_expirations` and `do_expirations`.
116                    // Note that `expirations` is local to one `LetRec`, because a `LetRec` can't
117                    // reference something that is defined in an inner `LetRec`, so a definition in
118                    // an inner `LetRec` can't expire something from an outer `LetRec`.
119                    let mut expirations = BTreeMap::new();
120                    for (id, value) in ids.iter().zip_eq(values.iter_mut()) {
121                        // 1. Recursive call. This has to be before 2. to avoid problems when a
122                        // binding refers to itself.
123                        self.action(value, let_replacements, gets_behind_gets)?;
124
125                        // 2. Gather info from the `value` for use in later bindings and the body.
126                        let replacements_from_value =
127                            list_replacements(&*value, let_replacements, gets_behind_gets);
128                        let_replacements.insert(*id, replacements_from_value.clone());
129                        let value_as_filtered_gets = as_filtered_get(value, gets_behind_gets);
130                        gets_behind_gets.insert(*id, value_as_filtered_gets.clone());
131
132                        // 3. Collect expirations.
133                        for replacement in replacements_from_value {
134                            MirRelationExpr::collect_expirations(
135                                *id,
136                                &replacement.replacement,
137                                &mut expirations,
138                            );
139                        }
140                        for referenced_id in
141                            value_as_filtered_gets
142                                .iter()
143                                .filter_map(|(id, _filter)| match id {
144                                    Id::Local(lid) => Some(lid),
145                                    _ => None,
146                                })
147                        {
148                            if referenced_id >= id {
149                                expirations
150                                    .entry(*referenced_id)
151                                    .or_insert_with(Vec::new)
152                                    .push(*id);
153                            }
154                        }
155
156                        // 4. Perform expirations.
157                        MirRelationExpr::do_expirations(*id, &mut expirations, let_replacements);
158                        MirRelationExpr::do_expirations(*id, &mut expirations, gets_behind_gets);
159                    }
160                    self.action(body, let_replacements, gets_behind_gets)?;
161                }
162                MirRelationExpr::Join {
163                    inputs,
164                    equivalences,
165                    implementation,
166                    ..
167                } => {
168                    attempt_join_simplification(
169                        inputs,
170                        equivalences,
171                        implementation,
172                        let_replacements,
173                        gets_behind_gets,
174                    );
175                    for input in inputs {
176                        self.action(input, let_replacements, gets_behind_gets)?;
177                    }
178                }
179                _ => {
180                    for child in expr.children_mut() {
181                        self.action(child, let_replacements, gets_behind_gets)?;
182                    }
183                }
184            }
185            Ok::<(), crate::TransformError>(())
186        })
187    }
188}
189
190/// Attempt to simplify the join using local information and let bindings.
191fn attempt_join_simplification(
192    inputs: &mut [MirRelationExpr],
193    equivalences: &Vec<Vec<MirScalarExpr>>,
194    implementation: &mut mz_expr::JoinImplementation,
195    let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
196    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
197) {
198    // Useful join manipulation helper.
199    let input_mapper = JoinInputMapper::new(inputs);
200
201    if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
202        // If semijoin_bijection returns `Some(...)`, then `inputs.len() == 2`.
203        assert_eq!(inputs.len(), 2);
204
205        // Collect the `Get` identifiers each input might present as.
206        let ids0 = as_filtered_get(&inputs[0], gets_behind_gets)
207            .iter()
208            .map(|(id, _)| *id)
209            .collect::<Vec<_>>();
210        let ids1 = as_filtered_get(&inputs[1], gets_behind_gets)
211            .iter()
212            .map(|(id, _)| *id)
213            .collect::<Vec<_>>();
214
215        // Record the types of the inputs, for use in both loops below.
216        let typ0 = inputs[0].typ();
217        let typ1 = inputs[1].typ();
218
219        // Consider replacing the second input for the benefit of the first.
220        if distinct_on_keys_of(&typ1, &rtl) && input_mapper.input_arity(1) == equivalences.len() {
221            for mut candidate in list_replacements(&inputs[1], let_replacements, gets_behind_gets) {
222                if ids0.contains(&candidate.id) {
223                    if let Some(permutation) = validate_replacement(&ltr, &mut candidate) {
224                        inputs[1] = candidate.replacement.project(permutation);
225                        *implementation = mz_expr::JoinImplementation::Unimplemented;
226
227                        // Take a moment to think about pushing down `IS NOT NULL` tests.
228                        // The pushdown is for the benefit of CSE on the `A` expressions,
229                        // in the not uncommon case of nullable foreign keys in outer joins.
230                        // TODO: Discover the transform that would not require this code.
231                        let mut is_not_nulls = Vec::new();
232                        for (col0, col1) in ltr.iter() {
233                            // We are using the pre-computed types; recomputing the types here
234                            // might alter nullability. As of 2025-01-09, Gábor has not found that
235                            // happening. But for the future, notice that this could be a source of
236                            // inaccurate or inconsistent nullability information.
237                            if !typ1.column_types[*col1].nullable
238                                && typ0.column_types[*col0].nullable
239                            {
240                                is_not_nulls.push(MirScalarExpr::Column(*col0).call_is_null().not())
241                            }
242                        }
243                        if !is_not_nulls.is_empty() {
244                            // Canonicalize otherwise arbitrary predicate order.
245                            is_not_nulls.sort();
246                            inputs[0] = inputs[0].take_dangerous().filter(is_not_nulls);
247                        }
248
249                        // GTFO because things are now crazy.
250                        return;
251                    }
252                }
253            }
254        }
255        // Consider replacing the first input for the benefit of the second.
256        if distinct_on_keys_of(&typ0, &ltr) && input_mapper.input_arity(0) == equivalences.len() {
257            for mut candidate in list_replacements(&inputs[0], let_replacements, gets_behind_gets) {
258                if ids1.contains(&candidate.id) {
259                    if let Some(permutation) = validate_replacement(&rtl, &mut candidate) {
260                        inputs[0] = candidate.replacement.project(permutation);
261                        *implementation = mz_expr::JoinImplementation::Unimplemented;
262
263                        // Take a moment to think about pushing down `IS NOT NULL` tests.
264                        // The pushdown is for the benefit of CSE on the `A` expressions,
265                        // in the not uncommon case of nullable foreign keys in outer joins.
266                        // TODO: Discover the transform that would not require this code.
267                        let mut is_not_nulls = Vec::new();
268                        for (col1, col0) in rtl.iter() {
269                            if !typ0.column_types[*col0].nullable
270                                && typ1.column_types[*col1].nullable
271                            {
272                                is_not_nulls.push(MirScalarExpr::Column(*col1).call_is_null().not())
273                            }
274                        }
275                        if !is_not_nulls.is_empty() {
276                            inputs[1] = inputs[1].take_dangerous().filter(is_not_nulls);
277                        }
278
279                        // GTFO because things are now crazy.
280                        return;
281                    }
282                }
283            }
284        }
285    }
286}
287
288/// Evaluates the viability of a `candidate` to drive the replacement at `semijoin`.
289///
290/// Returns a projection to apply to `candidate.replacement` if everything checks out.
291fn validate_replacement(
292    map: &BTreeMap<usize, usize>,
293    candidate: &mut Replacement,
294) -> Option<Vec<usize>> {
295    if candidate.columns.len() == map.len()
296        && candidate
297            .columns
298            .iter()
299            .all(|(c0, c1, _c2)| map.get(c0) == Some(c1))
300    {
301        candidate.columns.sort_by_key(|(_, c, _)| *c);
302        Some(
303            candidate
304                .columns
305                .iter()
306                .map(|(_, _, c)| *c)
307                .collect::<Vec<_>>(),
308        )
309    } else {
310        None
311    }
312}
313
314/// A restricted form of a semijoin idempotence information.
315///
316/// A `Replacement` may be offered up by any `MirRelationExpr`, meant to be `B` from above or similar,
317/// and indicates that the offered expression can be projected onto columns such that it then exactly equals
318/// a column projection of `Get{id} semijoin replacement`.
319///
320/// Specifically,
321/// the `columns` member lists indexes `(a, b, c)` where column `b` of the offering expression corresponds to
322/// columns `a` in `Get{id}` and `c` in `replacement`, and for which the semijoin requires `a = c`. The values
323/// of the projection of the offering expression onto the `b` indexes exactly equal the intersection of the
324/// projection of `Get{id}` onto the `a` indexes and the projection of `replacement` onto the `c` columns.
325#[derive(Clone, Debug)]
326struct Replacement {
327    id: Id,
328    columns: Vec<(usize, usize, usize)>,
329    replacement: MirRelationExpr,
330}
331
332/// Return a list of potential semijoin replacements for `expr`.
333///
334/// This method descends recursively, traversing `Get`, `Project`, `Reduce`, and `ArrangeBy` operators
335/// looking for a `Join` operator, at which point it defers to the `list_replacements_join` method.
336fn list_replacements(
337    expr: &MirRelationExpr,
338    let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
339    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
340) -> Vec<Replacement> {
341    let mut results = Vec::new();
342    match expr {
343        MirRelationExpr::Get {
344            id: Id::Local(lid), ..
345        } => {
346            // The `Get` may reference an `id` that offers semijoin replacements.
347            if let Some(replacements) = let_replacements.get(lid) {
348                results.extend(replacements.iter().cloned());
349            }
350        }
351        MirRelationExpr::Join {
352            inputs,
353            equivalences,
354            ..
355        } => {
356            results.extend(list_replacements_join(
357                inputs,
358                equivalences,
359                gets_behind_gets,
360            ));
361        }
362        MirRelationExpr::Project { input, outputs } => {
363            // If the columns are preserved by projection ..
364            results.extend(
365                list_replacements(input, let_replacements, gets_behind_gets)
366                    .into_iter()
367                    .filter_map(|mut replacement| {
368                        let new_cols = replacement
369                            .columns
370                            .iter()
371                            .filter_map(|(c0, c1, c2)| {
372                                outputs.iter().position(|o| o == c1).map(|c| (*c0, c, *c2))
373                            })
374                            .collect::<Vec<_>>();
375                        if new_cols.len() == replacement.columns.len() {
376                            replacement.columns = new_cols;
377                            Some(replacement)
378                        } else {
379                            None
380                        }
381                    }),
382            );
383        }
384        MirRelationExpr::Reduce {
385            input, group_key, ..
386        } => {
387            // If the columns are preserved by `group_key` ..
388            results.extend(
389                list_replacements(input, let_replacements, gets_behind_gets)
390                    .into_iter()
391                    .filter_map(|mut replacement| {
392                        let new_cols = replacement
393                            .columns
394                            .iter()
395                            .filter_map(|(c0, c1, c2)| {
396                                group_key
397                                    .iter()
398                                    .position(|o| o == &MirScalarExpr::Column(*c1))
399                                    .map(|c| (*c0, c, *c2))
400                            })
401                            .collect::<Vec<_>>();
402                        if new_cols.len() == replacement.columns.len() {
403                            replacement.columns = new_cols;
404                            Some(replacement)
405                        } else {
406                            None
407                        }
408                    }),
409            );
410        }
411        MirRelationExpr::ArrangeBy { input, .. } => {
412            results.extend(list_replacements(input, let_replacements, gets_behind_gets));
413        }
414        _ => {}
415    }
416    results
417}
418
419/// Return a list of potential semijoin replacements for `expr`.
420fn list_replacements_join(
421    inputs: &[MirRelationExpr],
422    equivalences: &Vec<Vec<MirScalarExpr>>,
423    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
424) -> Vec<Replacement> {
425    // Result replacements.
426    let mut results = Vec::new();
427
428    // If we are a binary join whose equivalence classes equate columns in the two inputs.
429    if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
430        // Each unique key could be a semijoin candidate.
431        // We want to check that the join equivalences exactly match the key,
432        // and then transcribe the corresponding columns in the other input.
433        if distinct_on_keys_of(&inputs[1].typ(), &rtl) {
434            let columns = ltr
435                .iter()
436                .map(|(k0, k1)| (*k0, *k0, *k1))
437                .collect::<Vec<_>>();
438
439            for (id, mut predicates) in as_filtered_get(&inputs[0], gets_behind_gets) {
440                if predicates
441                    .iter()
442                    .all(|e| e.support().iter().all(|c| ltr.contains_key(c)))
443                {
444                    for predicate in predicates.iter_mut() {
445                        predicate.permute_map(&ltr);
446                    }
447
448                    let mut replacement = inputs[1].clone();
449                    if !predicates.is_empty() {
450                        replacement = replacement.filter(predicates.clone());
451                    }
452                    results.push(Replacement {
453                        id,
454                        columns: columns.clone(),
455                        replacement,
456                    })
457                }
458            }
459        }
460        // Each unique key could be a semijoin candidate.
461        // We want to check that the join equivalences exactly match the key,
462        // and then transcribe the corresponding columns in the other input.
463        if distinct_on_keys_of(&inputs[0].typ(), &ltr) {
464            let columns = ltr
465                .iter()
466                .map(|(k0, k1)| (*k1, *k0, *k0))
467                .collect::<Vec<_>>();
468
469            for (id, mut predicates) in as_filtered_get(&inputs[1], gets_behind_gets) {
470                if predicates
471                    .iter()
472                    .all(|e| e.support().iter().all(|c| rtl.contains_key(c)))
473                {
474                    for predicate in predicates.iter_mut() {
475                        predicate.permute_map(&rtl);
476                    }
477
478                    let mut replacement = inputs[0].clone();
479                    if !predicates.is_empty() {
480                        replacement = replacement.filter(predicates.clone());
481                    }
482                    results.push(Replacement {
483                        id,
484                        columns: columns.clone(),
485                        replacement,
486                    })
487                }
488            }
489        }
490    }
491
492    results
493}
494
495/// True iff some unique key of `typ` is contained in the keys of `map`.
496fn distinct_on_keys_of(typ: &RelationType, map: &BTreeMap<usize, usize>) -> bool {
497    typ.keys
498        .iter()
499        .any(|key| key.iter().all(|k| map.contains_key(k)))
500}
501
502/// Attempts to interpret `expr` as filters applied to a `Get`.
503///
504/// Returns a list of such interpretations, potentially spanning `Let` bindings.
505fn as_filtered_get(
506    mut expr: &MirRelationExpr,
507    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
508) -> Vec<(Id, Vec<MirScalarExpr>)> {
509    let mut results = Vec::new();
510    while let MirRelationExpr::Filter { input, predicates } = expr {
511        results.extend(predicates.iter().cloned());
512        expr = &**input;
513    }
514    if let MirRelationExpr::Get { id, .. } = expr {
515        let mut output = Vec::new();
516        if let Id::Local(lid) = id {
517            if let Some(bound) = gets_behind_gets.get(lid) {
518                for (id, list) in bound.iter() {
519                    let mut predicates = list.clone();
520                    predicates.extend(results.iter().cloned());
521                    output.push((*id, predicates));
522                }
523            }
524        }
525        output.push((*id, results));
526        output
527    } else {
528        Vec::new()
529    }
530}
531
532/// Determines bijection between equated columns of a binary join.
533///
534/// Returns nothing if not a binary join, or if any equivalences are not of two opposing columns.
535/// Returned maps go from the column of the first input to those of the second, and vice versa.
536fn semijoin_bijection(
537    inputs: &[MirRelationExpr],
538    equivalences: &Vec<Vec<MirScalarExpr>>,
539) -> Option<(BTreeMap<usize, usize>, BTreeMap<usize, usize>)> {
540    // Useful join manipulation helper.
541    let input_mapper = JoinInputMapper::new(inputs);
542
543    // Pairs of equated columns localized to inputs 0 and 1.
544    let mut equiv_pairs = Vec::with_capacity(equivalences.len());
545
546    // Populate `equiv_pairs`, ideally finding exactly one pair for each equivalence class.
547    for eq in equivalences.iter() {
548        if eq.len() == 2 {
549            // The equivalence class could reference the inputs in either order, or be some
550            // tangle of references (e.g. to both) that we want to avoid reacting to.
551            match (
552                input_mapper.single_input(&eq[0]),
553                input_mapper.single_input(&eq[1]),
554            ) {
555                (Some(0), Some(1)) => {
556                    let expr0 = input_mapper.map_expr_to_local(eq[0].clone());
557                    let expr1 = input_mapper.map_expr_to_local(eq[1].clone());
558                    if let (MirScalarExpr::Column(col0), MirScalarExpr::Column(col1)) =
559                        (expr0, expr1)
560                    {
561                        equiv_pairs.push((col0, col1));
562                    }
563                }
564                (Some(1), Some(0)) => {
565                    let expr0 = input_mapper.map_expr_to_local(eq[1].clone());
566                    let expr1 = input_mapper.map_expr_to_local(eq[0].clone());
567                    if let (MirScalarExpr::Column(col0), MirScalarExpr::Column(col1)) =
568                        (expr0, expr1)
569                    {
570                        equiv_pairs.push((col0, col1));
571                    }
572                }
573                _ => {}
574            }
575        }
576    }
577
578    if inputs.len() == 2 && equiv_pairs.len() == equivalences.len() {
579        let ltr = equiv_pairs.iter().cloned().collect();
580        let rtl = equiv_pairs.iter().map(|(c0, c1)| (*c1, *c0)).collect();
581
582        Some((ltr, rtl))
583    } else {
584        None
585    }
586}