mz_transform/fusion/
join.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Fuses multiple `Join` operators into one `Join` operator.
11//!
12//! Multiway join planning relies on a broad view of the involved relations,
13//! and chains of binary joins can make this challenging to reason about.
14//! Collecting multiple joins together with their constraints improves
15//! our ability to plan these joins, and reason about other operators' motion
16//! around them.
17//!
18//! Also removes unit collections from joins, and joins with fewer than two inputs.
19//!
//! Unit collections have no columns and a count of one, and joining with such
//! a collection acts as the identity operator on collections. Once they are removed,
//! we may find joins with zero or one input, which can be further simplified.
23
24use std::collections::BTreeMap;
25
26use mz_expr::visit::Visit;
27use mz_expr::{BinaryFunc, VariadicFunc};
28use mz_expr::{MapFilterProject, MirRelationExpr, MirScalarExpr};
29
30use crate::analysis::equivalences::EquivalenceClasses;
31use crate::canonicalize_mfp::CanonicalizeMfp;
32use crate::predicate_pushdown::PredicatePushdown;
33use crate::{TransformCtx, TransformError};
34
/// Transform that collapses nested `Join` operators into a single `Join`.
///
/// Unit collections are dropped from join inputs, and joins that end up with fewer
/// than two inputs are replaced by simpler expressions. `Map`/`Filter`/`Project`
/// operators sitting on top of nested joins are lifted so those joins can be fused.
#[derive(Debug)]
pub struct Join;
41
42impl crate::Transform for Join {
43    fn name(&self) -> &'static str {
44        "JoinFusion"
45    }
46
47    #[mz_ore::instrument(
48        target = "optimizer",
49        level = "debug",
50        fields(path.segment = "join_fusion")
51    )]
52    fn actually_perform_transform(
53        &self,
54        relation: &mut MirRelationExpr,
55        _: &mut TransformCtx,
56    ) -> Result<(), TransformError> {
57        // We need to stick with post-order here because `action` only fuses a
58        // Join with its direct children. This means that we can only fuse a
59        // tree of Join nodes in a single pass if we work bottom-up.
60        let mut transformed = false;
61        relation.try_visit_mut_post(&mut |relation| {
62            transformed |= Self::action(relation)?;
63            Ok::<_, TransformError>(())
64        })?;
65        // If the action applied in the non-trivial case, run PredicatePushdown
66        // and CanonicalizeMfp in order to re-construct an equi-Join which would
67        // be de-constructed as a Filter + CrossJoin by the action application.
68        //
69        // TODO(database-issues#7728): This is a temporary solution which fixes the "Product
70        // limits" issue observed in a failed Nightly run when the PR was first
71        // tested (https://buildkite.com/materialize/nightly/builds/6670). We
72        // should re-evaluate if we need this ad-hoc re-normalization step when
73        // LiteralLifting is removed in favor of EquivalencePropagation.
74        if transformed {
75            PredicatePushdown::default().action(relation, &mut BTreeMap::new())?;
76            CanonicalizeMfp.action(relation)?
77        }
78        mz_repr::explain::trace_plan(&*relation);
79        Ok(())
80    }
81}
82
83impl Join {
84    /// Fuses multiple `Join` operators into one `Join` operator.
85    ///
86    /// Return Ok(true) iff the action manipulated the tree after detecting the
87    /// most general pattern.
88    pub fn action(relation: &mut MirRelationExpr) -> Result<bool, TransformError> {
89        if let MirRelationExpr::Join {
90            inputs,
91            equivalences,
92            ..
93        } = relation
94        {
95            // Local non-fusion tidying.
96            inputs.retain(|e| !e.is_constant_singleton());
97            if inputs.len() == 0 {
98                *relation = MirRelationExpr::constant(vec![vec![]], mz_repr::RelationType::empty())
99                    .filter(unpack_equivalences(equivalences));
100                return Ok(false);
101            }
102            if inputs.len() == 1 {
103                *relation = inputs
104                    .pop()
105                    .unwrap()
106                    .filter(unpack_equivalences(equivalences));
107                return Ok(false);
108            }
109
110            // Bail early if no children are MFPs around a Join
111            if inputs.iter().any(|mut expr| {
112                let mut result = None;
113                while result.is_none() {
114                    match expr {
115                        MirRelationExpr::Map { input, .. }
116                        | MirRelationExpr::Filter { input, .. }
117                        | MirRelationExpr::Project { input, .. } => {
118                            expr = &**input;
119                        }
120                        MirRelationExpr::Join { .. } => {
121                            result = Some(true);
122                        }
123                        _ => {
124                            result = Some(false);
125                        }
126                    }
127                }
128                result.unwrap()
129            }) {
130                // Each input is either an MFP around a Join, or just an expression.
131                let children = inputs
132                    .iter()
133                    .map(|expr| {
134                        let (mfp, inner) = MapFilterProject::extract_from_expression(expr);
135                        if let MirRelationExpr::Join {
136                            inputs,
137                            equivalences,
138                            ..
139                        } = inner
140                        {
141                            Ok((mfp, (inputs, equivalences)))
142                        } else {
143                            Err((mfp.projection.len(), expr))
144                        }
145                    })
146                    .collect::<Vec<_>>();
147
148                // Our plan is to append all subjoin inputs, and non-join expressions.
149                // Each join will lift its MFP to act on the whole product (via arity).
150                // The final join will also be wrapped with `equivalences` as predicates.
151
152                let mut outer_arity = children
153                    .iter()
154                    .map(|child| match child {
155                        Ok((mfp, _)) => mfp.input_arity,
156                        Err((arity, _)) => *arity,
157                    })
158                    .sum();
159
160                // We will accumulate the lifted transformations here.
161                let mut outer_mfp = MapFilterProject::new(outer_arity);
162
163                let mut arity_so_far = 0;
164
165                let mut new_inputs = Vec::new();
166                for child in children.into_iter() {
167                    match child {
168                        Ok((mut mfp, (inputs, equivalences))) => {
169                            // Add the join inputs to the new join inputs.
170                            new_inputs.extend(inputs.iter().cloned());
171
172                            mfp.optimize();
173                            let (mut map, mut filter, mut project) = mfp.as_map_filter_project();
174                            filter.extend(unpack_equivalences(equivalences));
175                            // We need to rewrite column references in map and filter.
176                            // the applied map elements will be at the end, starting at `outer_arity`.
177                            for expr in map.iter_mut() {
178                                expr.visit_pre_mut(|e| {
179                                    if let MirScalarExpr::Column(c) = e {
180                                        if *c >= mfp.input_arity {
181                                            *c -= mfp.input_arity;
182                                            *c += outer_arity;
183                                        } else {
184                                            *c += arity_so_far;
185                                        }
186                                    }
187                                });
188                            }
189                            for expr in filter.iter_mut() {
190                                expr.visit_pre_mut(|e| {
191                                    if let MirScalarExpr::Column(c) = e {
192                                        if *c >= mfp.input_arity {
193                                            *c -= mfp.input_arity;
194                                            *c += outer_arity;
195                                        } else {
196                                            *c += arity_so_far;
197                                        }
198                                    }
199                                });
200                            }
201                            for c in project.iter_mut() {
202                                if *c >= mfp.input_arity {
203                                    *c -= mfp.input_arity;
204                                    *c += outer_arity;
205                                } else {
206                                    *c += arity_so_far;
207                                }
208                            }
209
210                            outer_mfp = outer_mfp.map(map.clone());
211                            outer_mfp = outer_mfp.filter(filter);
212                            let projection = (0..arity_so_far)
213                                .chain(project.clone())
214                                .chain(arity_so_far + mfp.input_arity..outer_arity)
215                                .collect::<Vec<_>>();
216                            outer_mfp = outer_mfp.project(projection);
217
218                            outer_arity += project.len();
219                            outer_arity -= mfp.input_arity;
220                            arity_so_far += project.len();
221                        }
222                        Err((arity, expr)) => {
223                            new_inputs.push((*expr).clone());
224                            arity_so_far += arity;
225                        }
226                    }
227                }
228
229                new_inputs.retain(|e| !e.is_constant_singleton());
230
231                outer_mfp = outer_mfp.filter(unpack_equivalences(equivalences));
232                outer_mfp.optimize();
233                let (map, filter, project) = outer_mfp.as_map_filter_project();
234
235                *relation = match new_inputs.len() {
236                    0 => MirRelationExpr::constant(vec![vec![]], mz_repr::RelationType::empty()),
237                    1 => new_inputs.pop().unwrap(),
238                    _ => MirRelationExpr::join(new_inputs, Vec::new()),
239                }
240                .map(map)
241                .filter(filter)
242                .project(project);
243
244                return Ok(true);
245            }
246        }
247
248        Ok(false)
249    }
250}
251
252/// Unpacks multiple equivalence classes into conjuncts that should all be true, essentially
253/// turning join equivalences into a Filter.
254///
255/// Note that a join equivalence treats null equal to null, while an `=` in a Filter does not.
256/// This function is mindful of this.
257fn unpack_equivalences(equivalences: &Vec<Vec<MirScalarExpr>>) -> Vec<MirScalarExpr> {
258    let mut result = Vec::new();
259    for mut class in equivalences.iter().cloned() {
260        // Let's put the simplest expression at the beginning of `class`, because all the
261        // expressions will involve `class[0]`. For example, without sorting, we can get stuff like
262        // `Filter (#0 = 5) AND (#0 = #1)`.
263        // With sorting, this comes out as
264        // `Filter (#0 = 5) AND (#1 = 5)`.
265        // TODO: In the long term, we might want to call the entire `EquivalenceClasses::minimize`.
266        class.sort_by(EquivalenceClasses::mir_scalar_expr_complexity);
267        for expr in class[1..].iter() {
268            result.push(MirScalarExpr::CallVariadic {
269                func: VariadicFunc::Or,
270                exprs: vec![
271                    MirScalarExpr::CallBinary {
272                        func: BinaryFunc::Eq,
273                        expr1: Box::new(class[0].clone()),
274                        expr2: Box::new(expr.clone()),
275                    },
276                    MirScalarExpr::CallVariadic {
277                        func: VariadicFunc::And,
278                        exprs: vec![class[0].clone().call_is_null(), expr.clone().call_is_null()],
279                    },
280                ],
281            });
282        }
283    }
284    result
285}