mz_transform/semijoin_idempotence.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove semijoins that are applied multiple times to no further effect.
11//!
12//! Mechanically, this transform looks for instances of `A join B` and replaces
13//! `B` with a simpler `C`. It does this in the restricted setting that each `join`
14//! would be a "semijoin": a multiplicity preserving restriction.
15//!
16//! The approach we use here is to restrict our attention to cases where
17//!
18//! 1. `A` is a potentially filtered instance of some `Get{id}`,
//! 2. `A join B` equates columns of `A` to all columns of `B`,
20//! 3. The multiplicity of any record in `B` is at most one.
21//! 4. The values in these records are exactly `Get{id} join C`.
22//!
//! We find a candidate `C` by descending `B` looking for another semijoin between
//! `Get{id}` and some other collection `D` on the same columns that `A` means to join with `B`.
//! Should we find one, we allow arbitrary filters of `Get{id}` on the equated columns,
//! and transfer those filters to the columns of `D`, thereby forming `C`.
27
28use itertools::Itertools;
29use mz_repr::RelationType;
30use std::collections::BTreeMap;
31
32use mz_expr::{Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
33use mz_ore::id_gen::IdGen;
34use mz_ore::stack::{CheckedRecursion, RecursionGuard};
35
36use crate::TransformCtx;
37
38/// Remove redundant semijoin operators
39#[derive(Debug)]
40pub struct SemijoinIdempotence {
41 recursion_guard: RecursionGuard,
42}
43
44impl Default for SemijoinIdempotence {
45 fn default() -> SemijoinIdempotence {
46 SemijoinIdempotence {
47 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
48 }
49 }
50}
51
52impl CheckedRecursion for SemijoinIdempotence {
53 fn recursion_guard(&self) -> &RecursionGuard {
54 &self.recursion_guard
55 }
56}
57
58impl crate::Transform for SemijoinIdempotence {
59 fn name(&self) -> &'static str {
60 "SemijoinIdempotence"
61 }
62
63 #[mz_ore::instrument(
64 target = "optimizer",
65 level = "debug",
66 fields(path.segment = "semijoin_idempotence")
67 )]
68 fn actually_perform_transform(
69 &self,
70 relation: &mut MirRelationExpr,
71 _: &mut TransformCtx,
72 ) -> Result<(), crate::TransformError> {
73 // We need to call `renumber_bindings` because we will call
74 // `MirRelationExpr::collect_expirations`, which relies on this invariant.
75 crate::normalize_lets::renumber_bindings(relation, &mut IdGen::default())?;
76
77 let mut let_replacements = BTreeMap::<LocalId, Vec<Replacement>>::new();
78 let mut gets_behind_gets = BTreeMap::<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>::new();
79 self.action(relation, &mut let_replacements, &mut gets_behind_gets)?;
80
81 mz_repr::explain::trace_plan(&*relation);
82 Ok(())
83 }
84}
85
86impl SemijoinIdempotence {
87 /// * `let_replacements` - `Replacement`s offered up by CTEs.
88 /// * `gets_behind_gets` - The result of `as_filtered_get` called on CTEs.
89 fn action(
90 &self,
91 expr: &mut MirRelationExpr,
92 let_replacements: &mut BTreeMap<LocalId, Vec<Replacement>>,
93 gets_behind_gets: &mut BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
94 ) -> Result<(), crate::TransformError> {
95 // At each node, either gather info about Let bindings or attempt to simplify a join.
96 self.checked_recur(move |_| {
97 match expr {
98 MirRelationExpr::Let { id, value, body } => {
99 let_replacements.insert(
100 *id,
101 list_replacements(&*value, let_replacements, gets_behind_gets),
102 );
103 gets_behind_gets.insert(*id, as_filtered_get(value, gets_behind_gets));
104 self.action(value, let_replacements, gets_behind_gets)?;
105 self.action(body, let_replacements, gets_behind_gets)?;
106 // No need to do expirations here, as there is only one CTE (and it can't be
107 // recursive).
108 }
109 MirRelationExpr::LetRec {
110 ids,
111 values,
112 limits: _,
113 body,
114 } => {
115 // Expirations. See comments on `collect_expirations` and `do_expirations`.
116 // Note that `expirations` is local to one `LetRec`, because a `LetRec` can't
117 // reference something that is defined in an inner `LetRec`, so a definition in
118 // an inner `LetRec` can't expire something from an outer `LetRec`.
119 let mut expirations = BTreeMap::new();
120 for (id, value) in ids.iter().zip_eq(values.iter_mut()) {
121 // 1. Recursive call. This has to be before 2. to avoid problems when a
122 // binding refers to itself.
123 self.action(value, let_replacements, gets_behind_gets)?;
124
125 // 2. Gather info from the `value` for use in later bindings and the body.
126 let replacements_from_value =
127 list_replacements(&*value, let_replacements, gets_behind_gets);
128 let_replacements.insert(*id, replacements_from_value.clone());
129 let value_as_filtered_gets = as_filtered_get(value, gets_behind_gets);
130 gets_behind_gets.insert(*id, value_as_filtered_gets.clone());
131
132 // 3. Collect expirations.
133 for replacement in replacements_from_value {
134 MirRelationExpr::collect_expirations(
135 *id,
136 &replacement.replacement,
137 &mut expirations,
138 );
139 }
140 for referenced_id in
141 value_as_filtered_gets
142 .iter()
143 .filter_map(|(id, _filter)| match id {
144 Id::Local(lid) => Some(lid),
145 _ => None,
146 })
147 {
148 if referenced_id >= id {
149 expirations
150 .entry(*referenced_id)
151 .or_insert_with(Vec::new)
152 .push(*id);
153 }
154 }
155
156 // 4. Perform expirations.
157 MirRelationExpr::do_expirations(*id, &mut expirations, let_replacements);
158 MirRelationExpr::do_expirations(*id, &mut expirations, gets_behind_gets);
159 }
160 self.action(body, let_replacements, gets_behind_gets)?;
161 }
162 MirRelationExpr::Join {
163 inputs,
164 equivalences,
165 implementation,
166 ..
167 } => {
168 attempt_join_simplification(
169 inputs,
170 equivalences,
171 implementation,
172 let_replacements,
173 gets_behind_gets,
174 );
175 for input in inputs {
176 self.action(input, let_replacements, gets_behind_gets)?;
177 }
178 }
179 _ => {
180 for child in expr.children_mut() {
181 self.action(child, let_replacements, gets_behind_gets)?;
182 }
183 }
184 }
185 Ok::<(), crate::TransformError>(())
186 })
187 }
188}
189
190/// Attempt to simplify the join using local information and let bindings.
191fn attempt_join_simplification(
192 inputs: &mut [MirRelationExpr],
193 equivalences: &Vec<Vec<MirScalarExpr>>,
194 implementation: &mut mz_expr::JoinImplementation,
195 let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
196 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
197) {
198 // Useful join manipulation helper.
199 let input_mapper = JoinInputMapper::new(inputs);
200
201 if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
202 // If semijoin_bijection returns `Some(...)`, then `inputs.len() == 2`.
203 assert_eq!(inputs.len(), 2);
204
205 // Collect the `Get` identifiers each input might present as.
206 let ids0 = as_filtered_get(&inputs[0], gets_behind_gets)
207 .iter()
208 .map(|(id, _)| *id)
209 .collect::<Vec<_>>();
210 let ids1 = as_filtered_get(&inputs[1], gets_behind_gets)
211 .iter()
212 .map(|(id, _)| *id)
213 .collect::<Vec<_>>();
214
215 // Record the types of the inputs, for use in both loops below.
216 let typ0 = inputs[0].typ();
217 let typ1 = inputs[1].typ();
218
219 // Consider replacing the second input for the benefit of the first.
220 if distinct_on_keys_of(&typ1, &rtl) && input_mapper.input_arity(1) == equivalences.len() {
221 for mut candidate in list_replacements(&inputs[1], let_replacements, gets_behind_gets) {
222 if ids0.contains(&candidate.id) {
223 if let Some(permutation) = validate_replacement(<r, &mut candidate) {
224 inputs[1] = candidate.replacement.project(permutation);
225 *implementation = mz_expr::JoinImplementation::Unimplemented;
226
227 // Take a moment to think about pushing down `IS NOT NULL` tests.
228 // The pushdown is for the benefit of CSE on the `A` expressions,
229 // in the not uncommon case of nullable foreign keys in outer joins.
230 // TODO: Discover the transform that would not require this code.
231 let mut is_not_nulls = Vec::new();
232 for (col0, col1) in ltr.iter() {
233 // We are using the pre-computed types; recomputing the types here
234 // might alter nullability. As of 2025-01-09, Gábor has not found that
235 // happening. But for the future, notice that this could be a source of
236 // inaccurate or inconsistent nullability information.
237 if !typ1.column_types[*col1].nullable
238 && typ0.column_types[*col0].nullable
239 {
240 is_not_nulls.push(MirScalarExpr::Column(*col0).call_is_null().not())
241 }
242 }
243 if !is_not_nulls.is_empty() {
244 // Canonicalize otherwise arbitrary predicate order.
245 is_not_nulls.sort();
246 inputs[0] = inputs[0].take_dangerous().filter(is_not_nulls);
247 }
248
249 // GTFO because things are now crazy.
250 return;
251 }
252 }
253 }
254 }
255 // Consider replacing the first input for the benefit of the second.
256 if distinct_on_keys_of(&typ0, <r) && input_mapper.input_arity(0) == equivalences.len() {
257 for mut candidate in list_replacements(&inputs[0], let_replacements, gets_behind_gets) {
258 if ids1.contains(&candidate.id) {
259 if let Some(permutation) = validate_replacement(&rtl, &mut candidate) {
260 inputs[0] = candidate.replacement.project(permutation);
261 *implementation = mz_expr::JoinImplementation::Unimplemented;
262
263 // Take a moment to think about pushing down `IS NOT NULL` tests.
264 // The pushdown is for the benefit of CSE on the `A` expressions,
265 // in the not uncommon case of nullable foreign keys in outer joins.
266 // TODO: Discover the transform that would not require this code.
267 let mut is_not_nulls = Vec::new();
268 for (col1, col0) in rtl.iter() {
269 if !typ0.column_types[*col0].nullable
270 && typ1.column_types[*col1].nullable
271 {
272 is_not_nulls.push(MirScalarExpr::Column(*col1).call_is_null().not())
273 }
274 }
275 if !is_not_nulls.is_empty() {
276 inputs[1] = inputs[1].take_dangerous().filter(is_not_nulls);
277 }
278
279 // GTFO because things are now crazy.
280 return;
281 }
282 }
283 }
284 }
285 }
286}
287
288/// Evaluates the viability of a `candidate` to drive the replacement at `semijoin`.
289///
290/// Returns a projection to apply to `candidate.replacement` if everything checks out.
291fn validate_replacement(
292 map: &BTreeMap<usize, usize>,
293 candidate: &mut Replacement,
294) -> Option<Vec<usize>> {
295 if candidate.columns.len() == map.len()
296 && candidate
297 .columns
298 .iter()
299 .all(|(c0, c1, _c2)| map.get(c0) == Some(c1))
300 {
301 candidate.columns.sort_by_key(|(_, c, _)| *c);
302 Some(
303 candidate
304 .columns
305 .iter()
306 .map(|(_, _, c)| *c)
307 .collect::<Vec<_>>(),
308 )
309 } else {
310 None
311 }
312}
313
314/// A restricted form of a semijoin idempotence information.
315///
316/// A `Replacement` may be offered up by any `MirRelationExpr`, meant to be `B` from above or similar,
317/// and indicates that the offered expression can be projected onto columns such that it then exactly equals
318/// a column projection of `Get{id} semijoin replacement`.
319///
320/// Specifically,
321/// the `columns` member lists indexes `(a, b, c)` where column `b` of the offering expression corresponds to
322/// columns `a` in `Get{id}` and `c` in `replacement`, and for which the semijoin requires `a = c`. The values
323/// of the projection of the offering expression onto the `b` indexes exactly equal the intersection of the
324/// projection of `Get{id}` onto the `a` indexes and the projection of `replacement` onto the `c` columns.
325#[derive(Clone, Debug)]
326struct Replacement {
327 id: Id,
328 columns: Vec<(usize, usize, usize)>,
329 replacement: MirRelationExpr,
330}
331
332/// Return a list of potential semijoin replacements for `expr`.
333///
334/// This method descends recursively, traversing `Get`, `Project`, `Reduce`, and `ArrangeBy` operators
335/// looking for a `Join` operator, at which point it defers to the `list_replacements_join` method.
336fn list_replacements(
337 expr: &MirRelationExpr,
338 let_replacements: &BTreeMap<LocalId, Vec<Replacement>>,
339 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
340) -> Vec<Replacement> {
341 let mut results = Vec::new();
342 match expr {
343 MirRelationExpr::Get {
344 id: Id::Local(lid), ..
345 } => {
346 // The `Get` may reference an `id` that offers semijoin replacements.
347 if let Some(replacements) = let_replacements.get(lid) {
348 results.extend(replacements.iter().cloned());
349 }
350 }
351 MirRelationExpr::Join {
352 inputs,
353 equivalences,
354 ..
355 } => {
356 results.extend(list_replacements_join(
357 inputs,
358 equivalences,
359 gets_behind_gets,
360 ));
361 }
362 MirRelationExpr::Project { input, outputs } => {
363 // If the columns are preserved by projection ..
364 results.extend(
365 list_replacements(input, let_replacements, gets_behind_gets)
366 .into_iter()
367 .filter_map(|mut replacement| {
368 let new_cols = replacement
369 .columns
370 .iter()
371 .filter_map(|(c0, c1, c2)| {
372 outputs.iter().position(|o| o == c1).map(|c| (*c0, c, *c2))
373 })
374 .collect::<Vec<_>>();
375 if new_cols.len() == replacement.columns.len() {
376 replacement.columns = new_cols;
377 Some(replacement)
378 } else {
379 None
380 }
381 }),
382 );
383 }
384 MirRelationExpr::Reduce {
385 input, group_key, ..
386 } => {
387 // If the columns are preserved by `group_key` ..
388 results.extend(
389 list_replacements(input, let_replacements, gets_behind_gets)
390 .into_iter()
391 .filter_map(|mut replacement| {
392 let new_cols = replacement
393 .columns
394 .iter()
395 .filter_map(|(c0, c1, c2)| {
396 group_key
397 .iter()
398 .position(|o| o == &MirScalarExpr::Column(*c1))
399 .map(|c| (*c0, c, *c2))
400 })
401 .collect::<Vec<_>>();
402 if new_cols.len() == replacement.columns.len() {
403 replacement.columns = new_cols;
404 Some(replacement)
405 } else {
406 None
407 }
408 }),
409 );
410 }
411 MirRelationExpr::ArrangeBy { input, .. } => {
412 results.extend(list_replacements(input, let_replacements, gets_behind_gets));
413 }
414 _ => {}
415 }
416 results
417}
418
419/// Return a list of potential semijoin replacements for `expr`.
420fn list_replacements_join(
421 inputs: &[MirRelationExpr],
422 equivalences: &Vec<Vec<MirScalarExpr>>,
423 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
424) -> Vec<Replacement> {
425 // Result replacements.
426 let mut results = Vec::new();
427
428 // If we are a binary join whose equivalence classes equate columns in the two inputs.
429 if let Some((ltr, rtl)) = semijoin_bijection(inputs, equivalences) {
430 // Each unique key could be a semijoin candidate.
431 // We want to check that the join equivalences exactly match the key,
432 // and then transcribe the corresponding columns in the other input.
433 if distinct_on_keys_of(&inputs[1].typ(), &rtl) {
434 let columns = ltr
435 .iter()
436 .map(|(k0, k1)| (*k0, *k0, *k1))
437 .collect::<Vec<_>>();
438
439 for (id, mut predicates) in as_filtered_get(&inputs[0], gets_behind_gets) {
440 if predicates
441 .iter()
442 .all(|e| e.support().iter().all(|c| ltr.contains_key(c)))
443 {
444 for predicate in predicates.iter_mut() {
445 predicate.permute_map(<r);
446 }
447
448 let mut replacement = inputs[1].clone();
449 if !predicates.is_empty() {
450 replacement = replacement.filter(predicates.clone());
451 }
452 results.push(Replacement {
453 id,
454 columns: columns.clone(),
455 replacement,
456 })
457 }
458 }
459 }
460 // Each unique key could be a semijoin candidate.
461 // We want to check that the join equivalences exactly match the key,
462 // and then transcribe the corresponding columns in the other input.
463 if distinct_on_keys_of(&inputs[0].typ(), <r) {
464 let columns = ltr
465 .iter()
466 .map(|(k0, k1)| (*k1, *k0, *k0))
467 .collect::<Vec<_>>();
468
469 for (id, mut predicates) in as_filtered_get(&inputs[1], gets_behind_gets) {
470 if predicates
471 .iter()
472 .all(|e| e.support().iter().all(|c| rtl.contains_key(c)))
473 {
474 for predicate in predicates.iter_mut() {
475 predicate.permute_map(&rtl);
476 }
477
478 let mut replacement = inputs[0].clone();
479 if !predicates.is_empty() {
480 replacement = replacement.filter(predicates.clone());
481 }
482 results.push(Replacement {
483 id,
484 columns: columns.clone(),
485 replacement,
486 })
487 }
488 }
489 }
490 }
491
492 results
493}
494
495/// True iff some unique key of `typ` is contained in the keys of `map`.
496fn distinct_on_keys_of(typ: &RelationType, map: &BTreeMap<usize, usize>) -> bool {
497 typ.keys
498 .iter()
499 .any(|key| key.iter().all(|k| map.contains_key(k)))
500}
501
502/// Attempts to interpret `expr` as filters applied to a `Get`.
503///
504/// Returns a list of such interpretations, potentially spanning `Let` bindings.
505fn as_filtered_get(
506 mut expr: &MirRelationExpr,
507 gets_behind_gets: &BTreeMap<LocalId, Vec<(Id, Vec<MirScalarExpr>)>>,
508) -> Vec<(Id, Vec<MirScalarExpr>)> {
509 let mut results = Vec::new();
510 while let MirRelationExpr::Filter { input, predicates } = expr {
511 results.extend(predicates.iter().cloned());
512 expr = &**input;
513 }
514 if let MirRelationExpr::Get { id, .. } = expr {
515 let mut output = Vec::new();
516 if let Id::Local(lid) = id {
517 if let Some(bound) = gets_behind_gets.get(lid) {
518 for (id, list) in bound.iter() {
519 let mut predicates = list.clone();
520 predicates.extend(results.iter().cloned());
521 output.push((*id, predicates));
522 }
523 }
524 }
525 output.push((*id, results));
526 output
527 } else {
528 Vec::new()
529 }
530}
531
532/// Determines bijection between equated columns of a binary join.
533///
534/// Returns nothing if not a binary join, or if any equivalences are not of two opposing columns.
535/// Returned maps go from the column of the first input to those of the second, and vice versa.
536fn semijoin_bijection(
537 inputs: &[MirRelationExpr],
538 equivalences: &Vec<Vec<MirScalarExpr>>,
539) -> Option<(BTreeMap<usize, usize>, BTreeMap<usize, usize>)> {
540 // Useful join manipulation helper.
541 let input_mapper = JoinInputMapper::new(inputs);
542
543 // Pairs of equated columns localized to inputs 0 and 1.
544 let mut equiv_pairs = Vec::with_capacity(equivalences.len());
545
546 // Populate `equiv_pairs`, ideally finding exactly one pair for each equivalence class.
547 for eq in equivalences.iter() {
548 if eq.len() == 2 {
549 // The equivalence class could reference the inputs in either order, or be some
550 // tangle of references (e.g. to both) that we want to avoid reacting to.
551 match (
552 input_mapper.single_input(&eq[0]),
553 input_mapper.single_input(&eq[1]),
554 ) {
555 (Some(0), Some(1)) => {
556 let expr0 = input_mapper.map_expr_to_local(eq[0].clone());
557 let expr1 = input_mapper.map_expr_to_local(eq[1].clone());
558 if let (MirScalarExpr::Column(col0), MirScalarExpr::Column(col1)) =
559 (expr0, expr1)
560 {
561 equiv_pairs.push((col0, col1));
562 }
563 }
564 (Some(1), Some(0)) => {
565 let expr0 = input_mapper.map_expr_to_local(eq[1].clone());
566 let expr1 = input_mapper.map_expr_to_local(eq[0].clone());
567 if let (MirScalarExpr::Column(col0), MirScalarExpr::Column(col1)) =
568 (expr0, expr1)
569 {
570 equiv_pairs.push((col0, col1));
571 }
572 }
573 _ => {}
574 }
575 }
576 }
577
578 if inputs.len() == 2 && equiv_pairs.len() == equivalences.len() {
579 let ltr = equiv_pairs.iter().cloned().collect();
580 let rtl = equiv_pairs.iter().map(|(c0, c1)| (*c1, *c0)).collect();
581
582 Some((ltr, rtl))
583 } else {
584 None
585 }
586}