mz_transform/semijoin_idempotence.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove semijoins that are applied multiple times to no further effect.
11//!
12//! Mechanically, this transform looks for instances of `A join B` and replaces
13//! `B` with a simpler `C`. It does this in the restricted setting that each `join`
14//! would be a "semijoin": a multiplicity preserving restriction.
15//!
16//! The approach we use here is to restrict our attention to cases where
17//!
18//! 1. `A` is a potentially filtered instance of some `Get{id}`,
19//! 2. `A join B` equate columns of `A` to all columns of `B`,
20//! 3. The multiplicity of any record in `B` is at most one.
21//! 4. The values in these records are exactly `Get{id} join C`.
22//!
//! We find a candidate `C` by descending `B` looking for another semijoin between
//! `Get{id}` and some other collection `D`, on the same columns that `A` means to join with `B`.
//! Should we find one, we allow arbitrary filters of `Get{id}` on the equated columns,
//! and transfer those filters to the columns of `D`, thereby forming `C`.
27
28use itertools::Itertools;
29use mz_repr::ReprRelationType;
30use std::collections::BTreeMap;
31
32use mz_expr::{
33 Columns, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
34};
35use mz_ore::id_gen::IdGen;
36use mz_ore::stack::{CheckedRecursion, RecursionGuard};
37
38use crate::TransformCtx;
39
/// Remove redundant semijoin operators.
///
/// The only state is a recursion guard, which is consulted each time the transform
/// descends into a sub-expression.
#[derive(Debug)]
pub struct SemijoinIdempotence {
    /// Guard returned by the `CheckedRecursion` implementation of this type.
    recursion_guard: RecursionGuard,
}
45
46impl Default for SemijoinIdempotence {
47 fn default() -> SemijoinIdempotence {
48 SemijoinIdempotence {
49 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
50 }
51 }
52}
53
/// Allows `SemijoinIdempotence` to use `checked_recur` for its recursive traversal.
impl CheckedRecursion for SemijoinIdempotence {
    /// Returns the guard stored on the transform.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
59
60impl crate::Transform for SemijoinIdempotence {
61 fn name(&self) -> &'static str {
62 "SemijoinIdempotence"
63 }
64
65 #[mz_ore::instrument(
66 target = "optimizer",
67 level = "debug",
68 fields(path.segment = "semijoin_idempotence")
69 )]
70 fn actually_perform_transform(
71 &self,
72 relation: &mut MirRelationExpr,
73 _: &mut TransformCtx,
74 ) -> Result<(), crate::TransformError> {
75 // We need to call `renumber_bindings` because we will call
76 // `MirRelationExpr::collect_expirations`, which relies on this invariant.
77 crate::normalize_lets::renumber_bindings(relation, &mut IdGen::default())?;
78
79 let mut let_replacements = BTreeMap::<LocalId, Vec<Replacement>>::new();
80 let mut gets_behind_gets = BTreeMap::<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>::new();
81 self.action(relation, &mut let_replacements, &mut gets_behind_gets)?;
82
83 mz_repr::explain::trace_plan(&*relation);
84 Ok(())
85 }
86}
87
88impl SemijoinIdempotence {
89 /// * `let_replacements` - `Replacement`s offered up by CTEs.
90 /// * `gets_behind_gets` - The result of `as_filtered_get` called on CTEs.
91 fn action(
92 &self,
93 expr: &mut MirRelationExpr,
94 let_replacements: &mut BTreeMap<LocalId, Vec<Replacement>>,
95 gets_behind_gets: &mut BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
96 ) -> Result<(), crate::TransformError> {
97 // At each node, either gather info about Let bindings or attempt to simplify a join.
98 self.checked_recur(move |_| {
99 match expr {
100 MirRelationExpr::Let { id, value, body } => {
101 let_replacements.insert(
102 *id,
103 list_replacements(&*value, let_replacements, gets_behind_gets),
104 );
105 gets_behind_gets.insert(*id, as_filtered_get(value, gets_behind_gets));
106 self.action(value, let_replacements, gets_behind_gets)?;
107 self.action(body, let_replacements, gets_behind_gets)?;
108 // No need to do expirations here, as there is only one CTE (and it can't be
109 // recursive).
110 }
111 MirRelationExpr::LetRec {
112 ids,
113 values,
114 limits: _,
115 body,
116 } => {
117 // Expirations. See comments on `collect_expirations` and `do_expirations`.
118 // Note that `expirations` is local to one `LetRec`, because a `LetRec` can't
119 // reference something that is defined in an inner `LetRec`, so a definition in
120 // an inner `LetRec` can't expire something from an outer `LetRec`.
121 let mut expirations = BTreeMap::new();
122 for (id, value) in ids.iter().zip_eq(values.iter_mut()) {
123 // 1. Recursive call. This has to be before 2. to avoid problems when a
124 // binding refers to itself.
125 self.action(value, let_replacements, gets_behind_gets)?;
126
127 // 2. Gather info from the `value` for use in later bindings and the body.
128 let replacements_from_value =
129 list_replacements(&*value, let_replacements, gets_behind_gets);
130 let_replacements.insert(*id, replacements_from_value.clone());
131 let value_as_filtered_gets = as_filtered_get(value, gets_behind_gets);
132 gets_behind_gets.insert(*id, value_as_filtered_gets.clone());
133
134 // 3. Collect expirations.
135 for replacement in replacements_from_value {
136 MirRelationExpr::collect_expirations(
137 *id,
138 &replacement.replacement,
139 &mut expirations,
140 );
141 }
142 for referenced_id in
143 value_as_filtered_gets
144 .iter()
145 .filter_map(|(id, _filter)| match id {
146 Id::Local(lid) => Some(lid),
147 Id::Global(_) => None,
148 })
149 {
150 if referenced_id >= id {
151 expirations
152 .entry(*referenced_id)
153 .or_insert_with(Vec::new)
154 .push(*id);
155 }
156 }
157
158 // 4. Perform expirations.
159 MirRelationExpr::do_expirations(*id, &mut expirations, let_replacements);
160 MirRelationExpr::do_expirations(*id, &mut expirations, gets_behind_gets);
161 }
162 self.action(body, let_replacements, gets_behind_gets)?;
163 }
164 MirRelationExpr::Join {
165 inputs,
166 equivalences,
167 implementation,
168 ..
169 } => {
170 attempt_join_simplification(
171 inputs,
172 equivalences,
173 implementation,
174 let_replacements,
175 gets_behind_gets,
176 );
177 for input in inputs {
178 self.action(input, let_replacements, gets_behind_gets)?;
179 }
180 }
181 _ => {
182 for child in expr.children_mut() {
183 self.action(child, let_replacements, gets_behind_gets)?;
184 }
185 }
186 }
187 Ok::<(), crate::TransformError>(())
188 })
189 }
190}
191
192/// Attempt to simplify the join using local information and let bindings.
193fn attempt_join_simplification(
194 inputs: &mut [MirRelationExpr],
195 equivalences: &Vec<Vec<MirScalarExpr>>,
196 implementation: &mut mz_expr::JoinImplementation,
197 let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
198 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
199) {
200 // Useful join manipulation helper.
201 let input_mapper = JoinInputMapper::new(inputs);
202
203 if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
204 // If semijoin_bijection returns `Some(...)`, then `inputs.len() == 2`.
205 assert_eq!(inputs.len(), 2);
206
207 // Collect the `Get` identifiers each input might present as.
208 let ids0 = as_filtered_get(&inputs[0], gets_behind_gets)
209 .iter()
210 .map(|(id, _)| *id)
211 .collect::<Vec<_>>();
212 let ids1 = as_filtered_get(&inputs[1], gets_behind_gets)
213 .iter()
214 .map(|(id, _)| *id)
215 .collect::<Vec<_>>();
216
217 // Record the types of the inputs, for use in both loops below.
218 let typ0 = inputs[0].typ();
219 let typ1 = inputs[1].typ();
220
221 // Consider replacing the second input for the benefit of the first.
222 if distinct_on_keys_of(&typ1, &rtl) && input_mapper.input_arity(1) == equivalences.len() {
223 for mut candidate in list_replacements(&inputs[1], let_replacements, gets_behind_gets) {
224 if ids0.contains(&candidate.id) {
225 if let Some(permutation) = validate_replacement(<r, &mut candidate) {
226 inputs[1] = candidate.replacement.project(permutation);
227 *implementation = mz_expr::JoinImplementation::Unimplemented;
228
229 // Take a moment to think about pushing down `IS NOT NULL` tests.
230 // The pushdown is for the benefit of CSE on the `A` expressions,
231 // in the not uncommon case of nullable foreign keys in outer joins.
232 // TODO: Discover the transform that would not require this code.
233 let mut is_not_nulls = Vec::new();
234 for (col0, col1) in ltr.iter() {
235 // We are using the pre-computed types; recomputing the types here
236 // might alter nullability. As of 2025-01-09, Gábor has not found that
237 // happening. But for the future, notice that this could be a source of
238 // inaccurate or inconsistent nullability information.
239 if !typ1.column_types[*col1].nullable
240 && typ0.column_types[*col0].nullable
241 {
242 is_not_nulls.push(MirScalarExpr::column(*col0).call_is_null().not())
243 }
244 }
245 if !is_not_nulls.is_empty() {
246 // Canonicalize otherwise arbitrary predicate order.
247 is_not_nulls.sort();
248 inputs[0] = inputs[0].take_dangerous().filter(is_not_nulls);
249 }
250
251 // GTFO because things are now crazy.
252 return;
253 }
254 }
255 }
256 }
257 // Consider replacing the first input for the benefit of the second.
258 if distinct_on_keys_of(&typ0, <r) && input_mapper.input_arity(0) == equivalences.len() {
259 for mut candidate in list_replacements(&inputs[0], let_replacements, gets_behind_gets) {
260 if ids1.contains(&candidate.id) {
261 if let Some(permutation) = validate_replacement(&rtl, &mut candidate) {
262 inputs[0] = candidate.replacement.project(permutation);
263 *implementation = mz_expr::JoinImplementation::Unimplemented;
264
265 // Take a moment to think about pushing down `IS NOT NULL` tests.
266 // The pushdown is for the benefit of CSE on the `A` expressions,
267 // in the not uncommon case of nullable foreign keys in outer joins.
268 // TODO: Discover the transform that would not require this code.
269 let mut is_not_nulls = Vec::new();
270 for (col1, col0) in rtl.iter() {
271 if !typ0.column_types[*col0].nullable
272 && typ1.column_types[*col1].nullable
273 {
274 is_not_nulls.push(MirScalarExpr::column(*col1).call_is_null().not())
275 }
276 }
277 if !is_not_nulls.is_empty() {
278 inputs[1] = inputs[1].take_dangerous().filter(is_not_nulls);
279 }
280
281 // GTFO because things are now crazy.
282 return;
283 }
284 }
285 }
286 }
287 }
288}
289
290/// Evaluates the viability of a `candidate` to drive the replacement at `semijoin`.
291///
292/// Returns a projection to apply to `candidate.replacement` if everything checks out.
293fn validate_replacement(
294 map: &BTreeMap<usize, usize>,
295 candidate: &mut Replacement,
296) -> Option<Vec<usize>> {
297 if candidate.columns.len() == map.len()
298 && candidate
299 .columns
300 .iter()
301 .all(|(c0, c1, _c2)| map.get(c0) == Some(c1))
302 {
303 candidate.columns.sort_by_key(|(_, c, _)| *c);
304 Some(
305 candidate
306 .columns
307 .iter()
308 .map(|(_, _, c)| *c)
309 .collect::<Vec<_>>(),
310 )
311 } else {
312 None
313 }
314}
315
/// A restricted form of semijoin idempotence information.
///
/// A `Replacement` may be offered up by any `MirRelationExpr`, meant to be `B` from above or similar,
/// and indicates that the offered expression can be projected onto columns such that it then exactly equals
/// a column projection of `Get{id} semijoin replacement`.
///
/// Specifically,
/// the `columns` member lists indexes `(a, b, c)` where column `b` of the offering expression corresponds to
/// column `a` in `Get{id}` and column `c` in `replacement`, and for which the semijoin requires `a = c`. The values
/// of the projection of the offering expression onto the `b` indexes exactly equal the intersection of the
/// projection of `Get{id}` onto the `a` indexes and the projection of `replacement` onto the `c` columns.
#[derive(Clone, Debug)]
struct Replacement {
    /// Identifier of the `Get` that the offering expression is a semijoin of.
    id: Id,
    /// Triples `(a, b, c)`: column in `Get{id}`, column in the offering expression,
    /// and column in `replacement`, respectively.
    columns: Vec<(usize, usize, usize)>,
    /// The expression that `Get{id}` is semijoined with.
    replacement: MirRelationExpr,
}
333
334/// Return a list of potential semijoin replacements for `expr`.
335///
336/// This method descends recursively, traversing `Get`, `Project`, `Reduce`, and `ArrangeBy` operators
337/// looking for a `Join` operator, at which point it defers to the `list_replacements_join` method.
338fn list_replacements(
339 expr: &MirRelationExpr,
340 let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
341 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
342) -> Vec<Replacement> {
343 let mut results = Vec::new();
344 match expr {
345 MirRelationExpr::Get {
346 id: Id::Local(lid), ..
347 } => {
348 // The `Get` may reference an `id` that offers semijoin replacements.
349 if let Some(replacements) = let_replacements.get(lid) {
350 results.extend(replacements.iter().cloned());
351 }
352 }
353 MirRelationExpr::Join {
354 inputs,
355 equivalences,
356 ..
357 } => {
358 results.extend(list_replacements_join(
359 inputs,
360 equivalences,
361 gets_behind_gets,
362 ));
363 }
364 MirRelationExpr::Project { input, outputs } => {
365 // If the columns are preserved by projection ..
366 results.extend(
367 list_replacements(input, let_replacements, gets_behind_gets)
368 .into_iter()
369 .filter_map(|mut replacement| {
370 let new_cols = replacement
371 .columns
372 .iter()
373 .filter_map(|(c0, c1, c2)| {
374 outputs.iter().position(|o| o == c1).map(|c| (*c0, c, *c2))
375 })
376 .collect::<Vec<_>>();
377 if new_cols.len() == replacement.columns.len() {
378 replacement.columns = new_cols;
379 Some(replacement)
380 } else {
381 None
382 }
383 }),
384 );
385 }
386 MirRelationExpr::Reduce {
387 input, group_key, ..
388 } => {
389 // If the columns are preserved by `group_key` ..
390 results.extend(
391 list_replacements(input, let_replacements, gets_behind_gets)
392 .into_iter()
393 .filter_map(|mut replacement| {
394 let new_cols = replacement
395 .columns
396 .iter()
397 .filter_map(|(c0, c1, c2)| {
398 group_key
399 .iter()
400 .position(|o| o.as_column() == Some(*c1))
401 .map(|c| (*c0, c, *c2))
402 })
403 .collect::<Vec<_>>();
404 if new_cols.len() == replacement.columns.len() {
405 replacement.columns = new_cols;
406 Some(replacement)
407 } else {
408 None
409 }
410 }),
411 );
412 }
413 MirRelationExpr::ArrangeBy { input, .. } => {
414 results.extend(list_replacements(input, let_replacements, gets_behind_gets));
415 }
416 _ => {}
417 }
418 results
419}
420
421/// Return a list of potential semijoin replacements for `expr`.
422fn list_replacements_join(
423 inputs: &[MirRelationExpr],
424 equivalences: &Vec<Vec<MirScalarExpr>>,
425 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
426) -> Vec<Replacement> {
427 // Result replacements.
428 let mut results = Vec::new();
429
430 // If we are a binary join whose equivalence classes equate columns in the two inputs.
431 if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
432 // Each unique key could be a semijoin candidate.
433 // We want to check that the join equivalences exactly match the key,
434 // and then transcribe the corresponding columns in the other input.
435 if distinct_on_keys_of(&inputs[1].typ(), &rtl) {
436 let columns = ltr
437 .iter()
438 .map(|(k0, k1)| (*k0, *k0, *k1))
439 .collect::<Vec<_>>();
440
441 for (id, mut predicates) in as_filtered_get(&inputs[0], gets_behind_gets) {
442 if predicates
443 .iter()
444 .all(|e| e.support().iter().all(|c| ltr.contains_key(c)))
445 {
446 for predicate in predicates.iter_mut() {
447 predicate.permute_map(<r);
448 }
449
450 let mut replacement = inputs[1].clone();
451 if !predicates.is_empty() {
452 replacement = replacement.filter(predicates.clone());
453 }
454 results.push(Replacement {
455 id,
456 columns: columns.clone(),
457 replacement,
458 })
459 }
460 }
461 }
462 // Each unique key could be a semijoin candidate.
463 // We want to check that the join equivalences exactly match the key,
464 // and then transcribe the corresponding columns in the other input.
465 if distinct_on_keys_of(&inputs[0].typ(), <r) {
466 let columns = ltr
467 .iter()
468 .map(|(k0, k1)| (*k1, *k0, *k0))
469 .collect::<Vec<_>>();
470
471 for (id, mut predicates) in as_filtered_get(&inputs[1], gets_behind_gets) {
472 if predicates
473 .iter()
474 .all(|e| e.support().iter().all(|c| rtl.contains_key(c)))
475 {
476 for predicate in predicates.iter_mut() {
477 predicate.permute_map(&rtl);
478 }
479
480 let mut replacement = inputs[0].clone();
481 if !predicates.is_empty() {
482 replacement = replacement.filter(predicates.clone());
483 }
484 results.push(Replacement {
485 id,
486 columns: columns.clone(),
487 replacement,
488 })
489 }
490 }
491 }
492 }
493
494 results
495}
496
497/// True iff some unique key of `typ` is contained in the keys of `map`.
498fn distinct_on_keys_of(typ: &ReprRelationType, map: &BTreeMap<usize, usize>) -> bool {
499 typ.keys
500 .iter()
501 .any(|key| key.iter().all(|k| map.contains_key(k)))
502}
503
504/// Attempts to interpret `expr` as filters applied to a `Get`.
505///
506/// Returns a list of such interpretations, potentially spanning `Let` bindings.
507fn as_filtered_get(
508 mut expr: &MirRelationExpr,
509 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
510) -> Vec<(Id, Vec<MirScalarExpr>)> {
511 let mut results = Vec::new();
512 while let MirRelationExpr::Filter { input, predicates } = expr {
513 results.extend(predicates.iter().cloned());
514 expr = &**input;
515 }
516 if let MirRelationExpr::Get { id, .. } = expr {
517 let mut output = Vec::new();
518 if let Id::Local(lid) = id {
519 if let Some(bound) = gets_behind_gets.get(lid) {
520 for (id, list) in bound.iter() {
521 let mut predicates = list.clone();
522 predicates.extend(results.iter().cloned());
523 output.push((*id, predicates));
524 }
525 }
526 }
527 output.push((*id, results));
528 output
529 } else {
530 Vec::new()
531 }
532}
533
534/// Determines bijection between equated columns of a binary join.
535///
536/// Returns nothing if not a binary join, or if any equivalences are not of two opposing columns.
537/// Returned maps go from the column of the first input to those of the second, and vice versa.
538fn semijoin_bijection(
539 inputs: &[MirRelationExpr],
540 equivalences: &Vec<Vec<MirScalarExpr>>,
541) -> Option<(BTreeMap<usize, usize>, BTreeMap<usize, usize>)> {
542 // Useful join manipulation helper.
543 let input_mapper = JoinInputMapper::new(inputs);
544
545 // Pairs of equated columns localized to inputs 0 and 1.
546 let mut equiv_pairs = Vec::with_capacity(equivalences.len());
547
548 // Populate `equiv_pairs`, ideally finding exactly one pair for each equivalence class.
549 // TODO(mgree) !!! store the column names
550 for eq in equivalences.iter() {
551 if eq.len() == 2 {
552 // The equivalence class could reference the inputs in either order, or be some
553 // tangle of references (e.g. to both) that we want to avoid reacting to.
554 match (
555 input_mapper.single_input(&eq[0]),
556 input_mapper.single_input(&eq[1]),
557 ) {
558 (Some(0), Some(1)) => {
559 let expr0 = input_mapper.map_expr_to_local(eq[0].clone());
560 let expr1 = input_mapper.map_expr_to_local(eq[1].clone());
561 if let (
562 MirScalarExpr::Column(col0, _name0),
563 MirScalarExpr::Column(col1, _name1),
564 ) = (expr0, expr1)
565 {
566 equiv_pairs.push((col0, col1));
567 }
568 }
569 (Some(1), Some(0)) => {
570 let expr0 = input_mapper.map_expr_to_local(eq[1].clone());
571 let expr1 = input_mapper.map_expr_to_local(eq[0].clone());
572 if let (
573 MirScalarExpr::Column(col0, _name0),
574 MirScalarExpr::Column(col1, _name1),
575 ) = (expr0, expr1)
576 {
577 equiv_pairs.push((col0, col1));
578 }
579 }
580 _ => {}
581 }
582 }
583 }
584
585 if inputs.len() == 2 && equiv_pairs.len() == equivalences.len() {
586 let ltr = equiv_pairs.iter().cloned().collect();
587 let rtl = equiv_pairs.iter().map(|(c0, c1)| (*c1, *c0)).collect();
588
589 Some((ltr, rtl))
590 } else {
591 None
592 }
593}