Skip to main content

mz_transform/
semijoin_idempotence.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Remove semijoins that are applied multiple times to no further effect.
//!
//! Mechanically, this transform looks for instances of `A join B` and replaces
//! `B` with a simpler `C`. It does this in the restricted setting that each `join`
//! would be a "semijoin": a multiplicity preserving restriction.
//!
//! The approach we use here is to restrict our attention to cases where
//!
//! 1. `A` is a potentially filtered instance of some `Get{id}`,
//! 2. `A join B` equates columns of `A` to all columns of `B`,
//! 3. The multiplicity of any record in `B` is at most one,
//! 4. The values in these records are exactly `Get{id} join C`.
//!
//! We find a candidate `C` by descending `B` looking for another semijoin between
//! `Get{id}` and some other collection `D` on the same columns that `A` means to join
//! with `B`. Should we find one, we allow arbitrary filters of `Get{id}` on the equated
//! columns, and transfer those filters to the columns of `D`, thereby forming `C`.
27
28use itertools::Itertools;
29use mz_repr::ReprRelationType;
30use std::collections::BTreeMap;
31
32use mz_expr::{
33    Columns, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
34};
35use mz_ore::id_gen::IdGen;
36use mz_ore::stack::{CheckedRecursion, RecursionGuard};
37
38use crate::TransformCtx;
39
/// Remove redundant semijoin operators
///
/// The transform recurses over the expression tree through `checked_recur`,
/// so it carries its own [`RecursionGuard`].
#[derive(Debug)]
pub struct SemijoinIdempotence {
    /// Guard consulted by `checked_recur` on every recursive `action` call.
    recursion_guard: RecursionGuard,
}
45
46impl Default for SemijoinIdempotence {
47    fn default() -> SemijoinIdempotence {
48        SemijoinIdempotence {
49            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
50        }
51    }
52}
53
impl CheckedRecursion for SemijoinIdempotence {
    /// Exposes the guard stored on the transform so that `checked_recur` can use it.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
59
60impl crate::Transform for SemijoinIdempotence {
61    fn name(&self) -> &'static str {
62        "SemijoinIdempotence"
63    }
64
65    #[mz_ore::instrument(
66        target = "optimizer",
67        level = "debug",
68        fields(path.segment = "semijoin_idempotence")
69    )]
70    fn actually_perform_transform(
71        &self,
72        relation: &mut MirRelationExpr,
73        _: &mut TransformCtx,
74    ) -> Result<(), crate::TransformError> {
75        // We need to call `renumber_bindings` because we will call
76        // `MirRelationExpr::collect_expirations`, which relies on this invariant.
77        crate::normalize_lets::renumber_bindings(relation, &mut IdGen::default())?;
78
79        let mut let_replacements = BTreeMap::<LocalId, Vec<Replacement>>::new();
80        let mut gets_behind_gets = BTreeMap::<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>::new();
81        self.action(relation, &mut let_replacements, &mut gets_behind_gets)?;
82
83        mz_repr::explain::trace_plan(&*relation);
84        Ok(())
85    }
86}
87
88impl SemijoinIdempotence {
89    /// * `let_replacements` - `Replacement`s offered up by CTEs.
90    /// * `gets_behind_gets` - The result of `as_filtered_get` called on CTEs.
91    fn action(
92        &self,
93        expr: &mut MirRelationExpr,
94        let_replacements: &mut BTreeMap<LocalId, Vec<Replacement>>,
95        gets_behind_gets: &mut BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
96    ) -> Result<(), crate::TransformError> {
97        // At each node, either gather info about Let bindings or attempt to simplify a join.
98        self.checked_recur(move |_| {
99            match expr {
100                MirRelationExpr::Let { id, value, body } => {
101                    let_replacements.insert(
102                        *id,
103                        list_replacements(&*value, let_replacements, gets_behind_gets),
104                    );
105                    gets_behind_gets.insert(*id, as_filtered_get(value, gets_behind_gets));
106                    self.action(value, let_replacements, gets_behind_gets)?;
107                    self.action(body, let_replacements, gets_behind_gets)?;
108                    // No need to do expirations here, as there is only one CTE (and it can't be
109                    // recursive).
110                }
111                MirRelationExpr::LetRec {
112                    ids,
113                    values,
114                    limits: _,
115                    body,
116                } => {
117                    // Expirations. See comments on `collect_expirations` and `do_expirations`.
118                    // Note that `expirations` is local to one `LetRec`, because a `LetRec` can't
119                    // reference something that is defined in an inner `LetRec`, so a definition in
120                    // an inner `LetRec` can't expire something from an outer `LetRec`.
121                    let mut expirations = BTreeMap::new();
122                    for (id, value) in ids.iter().zip_eq(values.iter_mut()) {
123                        // 1. Recursive call. This has to be before 2. to avoid problems when a
124                        // binding refers to itself.
125                        self.action(value, let_replacements, gets_behind_gets)?;
126
127                        // 2. Gather info from the `value` for use in later bindings and the body.
128                        let replacements_from_value =
129                            list_replacements(&*value, let_replacements, gets_behind_gets);
130                        let_replacements.insert(*id, replacements_from_value.clone());
131                        let value_as_filtered_gets = as_filtered_get(value, gets_behind_gets);
132                        gets_behind_gets.insert(*id, value_as_filtered_gets.clone());
133
134                        // 3. Collect expirations.
135                        for replacement in replacements_from_value {
136                            MirRelationExpr::collect_expirations(
137                                *id,
138                                &replacement.replacement,
139                                &mut expirations,
140                            );
141                        }
142                        for referenced_id in
143                            value_as_filtered_gets
144                                .iter()
145                                .filter_map(|(id, _filter)| match id {
146                                    Id::Local(lid) => Some(lid),
147                                    Id::Global(_) => None,
148                                })
149                        {
150                            if referenced_id >= id {
151                                expirations
152                                    .entry(*referenced_id)
153                                    .or_insert_with(Vec::new)
154                                    .push(*id);
155                            }
156                        }
157
158                        // 4. Perform expirations.
159                        MirRelationExpr::do_expirations(*id, &mut expirations, let_replacements);
160                        MirRelationExpr::do_expirations(*id, &mut expirations, gets_behind_gets);
161                    }
162                    self.action(body, let_replacements, gets_behind_gets)?;
163                }
164                MirRelationExpr::Join {
165                    inputs,
166                    equivalences,
167                    implementation,
168                    ..
169                } => {
170                    attempt_join_simplification(
171                        inputs,
172                        equivalences,
173                        implementation,
174                        let_replacements,
175                        gets_behind_gets,
176                    );
177                    for input in inputs {
178                        self.action(input, let_replacements, gets_behind_gets)?;
179                    }
180                }
181                _ => {
182                    for child in expr.children_mut() {
183                        self.action(child, let_replacements, gets_behind_gets)?;
184                    }
185                }
186            }
187            Ok::<(), crate::TransformError>(())
188        })
189    }
190}
191
192/// Attempt to simplify the join using local information and let bindings.
193fn attempt_join_simplification(
194    inputs: &mut [MirRelationExpr],
195    equivalences: &Vec<Vec<MirScalarExpr>>,
196    implementation: &mut mz_expr::JoinImplementation,
197    let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
198    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
199) {
200    // Useful join manipulation helper.
201    let input_mapper = JoinInputMapper::new(inputs);
202
203    if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
204        // If semijoin_bijection returns `Some(...)`, then `inputs.len() == 2`.
205        assert_eq!(inputs.len(), 2);
206
207        // Collect the `Get` identifiers each input might present as.
208        let ids0 = as_filtered_get(&inputs[0], gets_behind_gets)
209            .iter()
210            .map(|(id, _)| *id)
211            .collect::<Vec<_>>();
212        let ids1 = as_filtered_get(&inputs[1], gets_behind_gets)
213            .iter()
214            .map(|(id, _)| *id)
215            .collect::<Vec<_>>();
216
217        // Record the types of the inputs, for use in both loops below.
218        let typ0 = inputs[0].typ();
219        let typ1 = inputs[1].typ();
220
221        // Consider replacing the second input for the benefit of the first.
222        if distinct_on_keys_of(&typ1, &rtl) && input_mapper.input_arity(1) == equivalences.len() {
223            for mut candidate in list_replacements(&inputs[1], let_replacements, gets_behind_gets) {
224                if ids0.contains(&candidate.id) {
225                    if let Some(permutation) = validate_replacement(&ltr, &mut candidate) {
226                        inputs[1] = candidate.replacement.project(permutation);
227                        *implementation = mz_expr::JoinImplementation::Unimplemented;
228
229                        // Take a moment to think about pushing down `IS NOT NULL` tests.
230                        // The pushdown is for the benefit of CSE on the `A` expressions,
231                        // in the not uncommon case of nullable foreign keys in outer joins.
232                        // TODO: Discover the transform that would not require this code.
233                        let mut is_not_nulls = Vec::new();
234                        for (col0, col1) in ltr.iter() {
235                            // We are using the pre-computed types; recomputing the types here
236                            // might alter nullability. As of 2025-01-09, Gábor has not found that
237                            // happening. But for the future, notice that this could be a source of
238                            // inaccurate or inconsistent nullability information.
239                            if !typ1.column_types[*col1].nullable
240                                && typ0.column_types[*col0].nullable
241                            {
242                                is_not_nulls.push(MirScalarExpr::column(*col0).call_is_null().not())
243                            }
244                        }
245                        if !is_not_nulls.is_empty() {
246                            // Canonicalize otherwise arbitrary predicate order.
247                            is_not_nulls.sort();
248                            inputs[0] = inputs[0].take_dangerous().filter(is_not_nulls);
249                        }
250
251                        // GTFO because things are now crazy.
252                        return;
253                    }
254                }
255            }
256        }
257        // Consider replacing the first input for the benefit of the second.
258        if distinct_on_keys_of(&typ0, &ltr) && input_mapper.input_arity(0) == equivalences.len() {
259            for mut candidate in list_replacements(&inputs[0], let_replacements, gets_behind_gets) {
260                if ids1.contains(&candidate.id) {
261                    if let Some(permutation) = validate_replacement(&rtl, &mut candidate) {
262                        inputs[0] = candidate.replacement.project(permutation);
263                        *implementation = mz_expr::JoinImplementation::Unimplemented;
264
265                        // Take a moment to think about pushing down `IS NOT NULL` tests.
266                        // The pushdown is for the benefit of CSE on the `A` expressions,
267                        // in the not uncommon case of nullable foreign keys in outer joins.
268                        // TODO: Discover the transform that would not require this code.
269                        let mut is_not_nulls = Vec::new();
270                        for (col1, col0) in rtl.iter() {
271                            if !typ0.column_types[*col0].nullable
272                                && typ1.column_types[*col1].nullable
273                            {
274                                is_not_nulls.push(MirScalarExpr::column(*col1).call_is_null().not())
275                            }
276                        }
277                        if !is_not_nulls.is_empty() {
278                            inputs[1] = inputs[1].take_dangerous().filter(is_not_nulls);
279                        }
280
281                        // GTFO because things are now crazy.
282                        return;
283                    }
284                }
285            }
286        }
287    }
288}
289
290/// Evaluates the viability of a `candidate` to drive the replacement at `semijoin`.
291///
292/// Returns a projection to apply to `candidate.replacement` if everything checks out.
293fn validate_replacement(
294    map: &BTreeMap<usize, usize>,
295    candidate: &mut Replacement,
296) -> Option<Vec<usize>> {
297    if candidate.columns.len() == map.len()
298        && candidate
299            .columns
300            .iter()
301            .all(|(c0, c1, _c2)| map.get(c0) == Some(c1))
302    {
303        candidate.columns.sort_by_key(|(_, c, _)| *c);
304        Some(
305            candidate
306                .columns
307                .iter()
308                .map(|(_, _, c)| *c)
309                .collect::<Vec<_>>(),
310        )
311    } else {
312        None
313    }
314}
315
/// A restricted form of semijoin idempotence information.
///
/// A `Replacement` may be offered up by any `MirRelationExpr`, meant to be `B` from above or similar,
/// and indicates that the offered expression can be projected onto columns such that it then exactly equals
/// a column projection of `Get{id} semijoin replacement`.
///
/// Specifically,
/// the `columns` member lists indexes `(a, b, c)` where column `b` of the offering expression corresponds to
/// columns `a` in `Get{id}` and `c` in `replacement`, and for which the semijoin requires `a = c`. The values
/// of the projection of the offering expression onto the `b` indexes exactly equal the intersection of the
/// projection of `Get{id}` onto the `a` indexes and the projection of `replacement` onto the `c` columns.
#[derive(Clone, Debug)]
struct Replacement {
    /// Identifier of the `Get{id}` that the offering expression restricts.
    id: Id,
    /// Triples `(a, b, c)`: column `a` of `Get{id}`, column `b` of the offering
    /// expression, and column `c` of `replacement`.
    columns: Vec<(usize, usize, usize)>,
    /// The expression that `Get{id}` is semijoined with.
    replacement: MirRelationExpr,
}
333
334/// Return a list of potential semijoin replacements for `expr`.
335///
336/// This method descends recursively, traversing `Get`, `Project`, `Reduce`, and `ArrangeBy` operators
337/// looking for a `Join` operator, at which point it defers to the `list_replacements_join` method.
338fn list_replacements(
339    expr: &MirRelationExpr,
340    let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
341    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
342) -> Vec<Replacement> {
343    let mut results = Vec::new();
344    match expr {
345        MirRelationExpr::Get {
346            id: Id::Local(lid), ..
347        } => {
348            // The `Get` may reference an `id` that offers semijoin replacements.
349            if let Some(replacements) = let_replacements.get(lid) {
350                results.extend(replacements.iter().cloned());
351            }
352        }
353        MirRelationExpr::Join {
354            inputs,
355            equivalences,
356            ..
357        } => {
358            results.extend(list_replacements_join(
359                inputs,
360                equivalences,
361                gets_behind_gets,
362            ));
363        }
364        MirRelationExpr::Project { input, outputs } => {
365            // If the columns are preserved by projection ..
366            results.extend(
367                list_replacements(input, let_replacements, gets_behind_gets)
368                    .into_iter()
369                    .filter_map(|mut replacement| {
370                        let new_cols = replacement
371                            .columns
372                            .iter()
373                            .filter_map(|(c0, c1, c2)| {
374                                outputs.iter().position(|o| o == c1).map(|c| (*c0, c, *c2))
375                            })
376                            .collect::<Vec<_>>();
377                        if new_cols.len() == replacement.columns.len() {
378                            replacement.columns = new_cols;
379                            Some(replacement)
380                        } else {
381                            None
382                        }
383                    }),
384            );
385        }
386        MirRelationExpr::Reduce {
387            input, group_key, ..
388        } => {
389            // If the columns are preserved by `group_key` ..
390            results.extend(
391                list_replacements(input, let_replacements, gets_behind_gets)
392                    .into_iter()
393                    .filter_map(|mut replacement| {
394                        let new_cols = replacement
395                            .columns
396                            .iter()
397                            .filter_map(|(c0, c1, c2)| {
398                                group_key
399                                    .iter()
400                                    .position(|o| o.as_column() == Some(*c1))
401                                    .map(|c| (*c0, c, *c2))
402                            })
403                            .collect::<Vec<_>>();
404                        if new_cols.len() == replacement.columns.len() {
405                            replacement.columns = new_cols;
406                            Some(replacement)
407                        } else {
408                            None
409                        }
410                    }),
411            );
412        }
413        MirRelationExpr::ArrangeBy { input, .. } => {
414            results.extend(list_replacements(input, let_replacements, gets_behind_gets));
415        }
416        _ => {}
417    }
418    results
419}
420
421/// Return a list of potential semijoin replacements for `expr`.
422fn list_replacements_join(
423    inputs: &[MirRelationExpr],
424    equivalences: &Vec<Vec<MirScalarExpr>>,
425    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
426) -> Vec<Replacement> {
427    // Result replacements.
428    let mut results = Vec::new();
429
430    // If we are a binary join whose equivalence classes equate columns in the two inputs.
431    if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
432        // Each unique key could be a semijoin candidate.
433        // We want to check that the join equivalences exactly match the key,
434        // and then transcribe the corresponding columns in the other input.
435        if distinct_on_keys_of(&inputs[1].typ(), &rtl) {
436            let columns = ltr
437                .iter()
438                .map(|(k0, k1)| (*k0, *k0, *k1))
439                .collect::<Vec<_>>();
440
441            for (id, mut predicates) in as_filtered_get(&inputs[0], gets_behind_gets) {
442                if predicates
443                    .iter()
444                    .all(|e| e.support().iter().all(|c| ltr.contains_key(c)))
445                {
446                    for predicate in predicates.iter_mut() {
447                        predicate.permute_map(&ltr);
448                    }
449
450                    let mut replacement = inputs[1].clone();
451                    if !predicates.is_empty() {
452                        replacement = replacement.filter(predicates.clone());
453                    }
454                    results.push(Replacement {
455                        id,
456                        columns: columns.clone(),
457                        replacement,
458                    })
459                }
460            }
461        }
462        // Each unique key could be a semijoin candidate.
463        // We want to check that the join equivalences exactly match the key,
464        // and then transcribe the corresponding columns in the other input.
465        if distinct_on_keys_of(&inputs[0].typ(), &ltr) {
466            let columns = ltr
467                .iter()
468                .map(|(k0, k1)| (*k1, *k0, *k0))
469                .collect::<Vec<_>>();
470
471            for (id, mut predicates) in as_filtered_get(&inputs[1], gets_behind_gets) {
472                if predicates
473                    .iter()
474                    .all(|e| e.support().iter().all(|c| rtl.contains_key(c)))
475                {
476                    for predicate in predicates.iter_mut() {
477                        predicate.permute_map(&rtl);
478                    }
479
480                    let mut replacement = inputs[0].clone();
481                    if !predicates.is_empty() {
482                        replacement = replacement.filter(predicates.clone());
483                    }
484                    results.push(Replacement {
485                        id,
486                        columns: columns.clone(),
487                        replacement,
488                    })
489                }
490            }
491        }
492    }
493
494    results
495}
496
497/// True iff some unique key of `typ` is contained in the keys of `map`.
498fn distinct_on_keys_of(typ: &ReprRelationType, map: &BTreeMap<usize, usize>) -> bool {
499    typ.keys
500        .iter()
501        .any(|key| key.iter().all(|k| map.contains_key(k)))
502}
503
504/// Attempts to interpret `expr` as filters applied to a `Get`.
505///
506/// Returns a list of such interpretations, potentially spanning `Let` bindings.
507fn as_filtered_get(
508    mut expr: &MirRelationExpr,
509    gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
510) -> Vec<(Id, Vec<MirScalarExpr>)> {
511    let mut results = Vec::new();
512    while let MirRelationExpr::Filter { input, predicates } = expr {
513        results.extend(predicates.iter().cloned());
514        expr = &**input;
515    }
516    if let MirRelationExpr::Get { id, .. } = expr {
517        let mut output = Vec::new();
518        if let Id::Local(lid) = id {
519            if let Some(bound) = gets_behind_gets.get(lid) {
520                for (id, list) in bound.iter() {
521                    let mut predicates = list.clone();
522                    predicates.extend(results.iter().cloned());
523                    output.push((*id, predicates));
524                }
525            }
526        }
527        output.push((*id, results));
528        output
529    } else {
530        Vec::new()
531    }
532}
533
534/// Determines bijection between equated columns of a binary join.
535///
536/// Returns nothing if not a binary join, or if any equivalences are not of two opposing columns.
537/// Returned maps go from the column of the first input to those of the second, and vice versa.
538fn semijoin_bijection(
539    inputs: &[MirRelationExpr],
540    equivalences: &Vec<Vec<MirScalarExpr>>,
541) -> Option<(BTreeMap<usize, usize>, BTreeMap<usize, usize>)> {
542    // Useful join manipulation helper.
543    let input_mapper = JoinInputMapper::new(inputs);
544
545    // Pairs of equated columns localized to inputs 0 and 1.
546    let mut equiv_pairs = Vec::with_capacity(equivalences.len());
547
548    // Populate `equiv_pairs`, ideally finding exactly one pair for each equivalence class.
549    // TODO(mgree) !!! store the column names
550    for eq in equivalences.iter() {
551        if eq.len() == 2 {
552            // The equivalence class could reference the inputs in either order, or be some
553            // tangle of references (e.g. to both) that we want to avoid reacting to.
554            match (
555                input_mapper.single_input(&eq[0]),
556                input_mapper.single_input(&eq[1]),
557            ) {
558                (Some(0), Some(1)) => {
559                    let expr0 = input_mapper.map_expr_to_local(eq[0].clone());
560                    let expr1 = input_mapper.map_expr_to_local(eq[1].clone());
561                    if let (
562                        MirScalarExpr::Column(col0, _name0),
563                        MirScalarExpr::Column(col1, _name1),
564                    ) = (expr0, expr1)
565                    {
566                        equiv_pairs.push((col0, col1));
567                    }
568                }
569                (Some(1), Some(0)) => {
570                    let expr0 = input_mapper.map_expr_to_local(eq[1].clone());
571                    let expr1 = input_mapper.map_expr_to_local(eq[0].clone());
572                    if let (
573                        MirScalarExpr::Column(col0, _name0),
574                        MirScalarExpr::Column(col1, _name1),
575                    ) = (expr0, expr1)
576                    {
577                        equiv_pairs.push((col0, col1));
578                    }
579                }
580                _ => {}
581            }
582        }
583    }
584
585    if inputs.len() == 2 && equiv_pairs.len() == equivalences.len() {
586        let ltr = equiv_pairs.iter().cloned().collect();
587        let rtl = equiv_pairs.iter().map(|(c0, c1)| (*c1, *c0)).collect();
588
589        Some((ltr, rtl))
590    } else {
591        None
592    }
593}