mz_transform/fusion/join.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Fuses multiple `Join` operators into one `Join` operator.
11//!
12//! Multiway join planning relies on a broad view of the involved relations,
13//! and chains of binary joins can make this challenging to reason about.
14//! Collecting multiple joins together with their constraints improves
15//! our ability to plan these joins, and reason about other operators' motion
16//! around them.
17//!
//! Also removes unit collections from joins, and simplifies away joins with fewer
//! than two inputs.
//!
//! Unit collections have no columns and a count of one, and joining with such
//! a collection acts as the identity operator on collections. Once they are removed,
//! we may find joins with zero or one input, which can be further simplified.
23
24use std::collections::BTreeMap;
25
26use mz_expr::VariadicFunc;
27use mz_expr::visit::Visit;
28use mz_expr::{MapFilterProject, MirRelationExpr, MirScalarExpr};
29
30use crate::analysis::equivalences::EquivalenceClasses;
31use crate::canonicalize_mfp::CanonicalizeMfp;
32use crate::predicate_pushdown::PredicatePushdown;
33use crate::{TransformCtx, TransformError};
34
/// Fuses multiple `Join` operators into one `Join` operator.
///
/// Removes unit collections from joins, and replaces joins with fewer than two
/// inputs by their sole input (or a unit constant) filtered by the join's
/// equivalences. `Map`, `Filter`, and `Project` operators on top of nested joins
/// are lifted so the nested joins can be fused into their parent.
///
/// When a fusion happens, `PredicatePushdown` and `CanonicalizeMfp` are run
/// afterwards to rebuild equi-joins from the filters left over the fused cross join.
#[derive(Debug)]
pub struct Join;
41
42impl crate::Transform for Join {
43    fn name(&self) -> &'static str {
44        "JoinFusion"
45    }
46
47    #[mz_ore::instrument(
48        target = "optimizer",
49        level = "debug",
50        fields(path.segment = "join_fusion")
51    )]
52    fn actually_perform_transform(
53        &self,
54        relation: &mut MirRelationExpr,
55        _: &mut TransformCtx,
56    ) -> Result<(), TransformError> {
57        // We need to stick with post-order here because `action` only fuses a
58        // Join with its direct children. This means that we can only fuse a
59        // tree of Join nodes in a single pass if we work bottom-up.
60        let mut transformed = false;
61        relation.try_visit_mut_post(&mut |relation| {
62            transformed |= Self::action(relation)?;
63            Ok::<_, TransformError>(())
64        })?;
65        // If the action applied in the non-trivial case, run PredicatePushdown
66        // and CanonicalizeMfp in order to re-construct an equi-Join which would
67        // be de-constructed as a Filter + CrossJoin by the action application.
68        //
69        // TODO(database-issues#7728): This is a temporary solution which fixes the "Product
70        // limits" issue observed in a failed Nightly run when the PR was first
71        // tested (https://buildkite.com/materialize/nightly/builds/6670). We
72        // should re-evaluate if we need this ad-hoc re-normalization step when
73        // LiteralLifting is removed in favor of EquivalencePropagation.
74        if transformed {
75            PredicatePushdown::default().action(relation, &mut BTreeMap::new())?;
76            CanonicalizeMfp.action(relation)?
77        }
78        mz_repr::explain::trace_plan(&*relation);
79        Ok(())
80    }
81}
82
83impl Join {
84    /// Fuses multiple `Join` operators into one `Join` operator.
85    ///
86    /// Return Ok(true) iff the action manipulated the tree after detecting the
87    /// most general pattern.
88    pub fn action(relation: &mut MirRelationExpr) -> Result<bool, TransformError> {
89        if let MirRelationExpr::Join {
90            inputs,
91            equivalences,
92            ..
93        } = relation
94        {
95            // Local non-fusion tidying.
96            inputs.retain(|e| !e.is_constant_singleton());
97            if inputs.len() == 0 {
98                *relation =
99                    MirRelationExpr::constant(vec![vec![]], mz_repr::SqlRelationType::empty())
100                        .filter(unpack_equivalences(equivalences));
101                return Ok(false);
102            }
103            if inputs.len() == 1 {
104                *relation = inputs
105                    .pop()
106                    .unwrap()
107                    .filter(unpack_equivalences(equivalences));
108                return Ok(false);
109            }
110
111            // Bail early if no children are MFPs around a Join
112            if inputs.iter().any(|mut expr| {
113                let mut result = None;
114                while result.is_none() {
115                    match expr {
116                        MirRelationExpr::Map { input, .. }
117                        | MirRelationExpr::Filter { input, .. }
118                        | MirRelationExpr::Project { input, .. } => {
119                            expr = &**input;
120                        }
121                        MirRelationExpr::Join { .. } => {
122                            result = Some(true);
123                        }
124                        _ => {
125                            result = Some(false);
126                        }
127                    }
128                }
129                result.unwrap()
130            }) {
131                // Each input is either an MFP around a Join, or just an expression.
132                let children = inputs
133                    .iter()
134                    .map(|expr| {
135                        let (mfp, inner) = MapFilterProject::extract_from_expression(expr);
136                        if let MirRelationExpr::Join {
137                            inputs,
138                            equivalences,
139                            ..
140                        } = inner
141                        {
142                            Ok((mfp, (inputs, equivalences)))
143                        } else {
144                            Err((mfp.projection.len(), expr))
145                        }
146                    })
147                    .collect::<Vec<_>>();
148
149                // Our plan is to append all subjoin inputs, and non-join expressions.
150                // Each join will lift its MFP to act on the whole product (via arity).
151                // The final join will also be wrapped with `equivalences` as predicates.
152
153                let mut outer_arity = children
154                    .iter()
155                    .map(|child| match child {
156                        Ok((mfp, _)) => mfp.input_arity,
157                        Err((arity, _)) => *arity,
158                    })
159                    .sum();
160
161                // We will accumulate the lifted transformations here.
162                let mut outer_mfp = MapFilterProject::new(outer_arity);
163
164                let mut arity_so_far = 0;
165
166                let mut new_inputs = Vec::new();
167                for child in children.into_iter() {
168                    match child {
169                        Ok((mut mfp, (inputs, equivalences))) => {
170                            // Add the join inputs to the new join inputs.
171                            new_inputs.extend(inputs.iter().cloned());
172
173                            mfp.optimize();
174                            let (mut map, mut filter, mut project) = mfp.as_map_filter_project();
175                            filter.extend(unpack_equivalences(equivalences));
176                            // We need to rewrite column references in map and filter.
177                            // the applied map elements will be at the end, starting at `outer_arity`.
178                            for expr in map.iter_mut() {
179                                expr.visit_pre_mut(|e| {
180                                    if let MirScalarExpr::Column(c, _) = e {
181                                        if *c >= mfp.input_arity {
182                                            *c -= mfp.input_arity;
183                                            *c += outer_arity;
184                                        } else {
185                                            *c += arity_so_far;
186                                        }
187                                    }
188                                });
189                            }
190                            for expr in filter.iter_mut() {
191                                expr.visit_pre_mut(|e| {
192                                    if let MirScalarExpr::Column(c, _) = e {
193                                        if *c >= mfp.input_arity {
194                                            *c -= mfp.input_arity;
195                                            *c += outer_arity;
196                                        } else {
197                                            *c += arity_so_far;
198                                        }
199                                    }
200                                });
201                            }
202                            for c in project.iter_mut() {
203                                if *c >= mfp.input_arity {
204                                    *c -= mfp.input_arity;
205                                    *c += outer_arity;
206                                } else {
207                                    *c += arity_so_far;
208                                }
209                            }
210
211                            outer_mfp = outer_mfp.map(map.clone());
212                            outer_mfp = outer_mfp.filter(filter);
213                            let projection = (0..arity_so_far)
214                                .chain(project.clone())
215                                .chain(arity_so_far + mfp.input_arity..outer_arity)
216                                .collect::<Vec<_>>();
217                            outer_mfp = outer_mfp.project(projection);
218
219                            outer_arity += project.len();
220                            outer_arity -= mfp.input_arity;
221                            arity_so_far += project.len();
222                        }
223                        Err((arity, expr)) => {
224                            new_inputs.push((*expr).clone());
225                            arity_so_far += arity;
226                        }
227                    }
228                }
229
230                new_inputs.retain(|e| !e.is_constant_singleton());
231
232                outer_mfp = outer_mfp.filter(unpack_equivalences(equivalences));
233                outer_mfp.optimize();
234                let (map, filter, project) = outer_mfp.as_map_filter_project();
235
236                *relation = match new_inputs.len() {
237                    0 => MirRelationExpr::constant(vec![vec![]], mz_repr::SqlRelationType::empty()),
238                    1 => new_inputs.pop().unwrap(),
239                    _ => MirRelationExpr::join(new_inputs, Vec::new()),
240                }
241                .map(map)
242                .filter(filter)
243                .project(project);
244
245                return Ok(true);
246            }
247        }
248
249        Ok(false)
250    }
251}
252
253/// Unpacks multiple equivalence classes into conjuncts that should all be true, essentially
254/// turning join equivalences into a Filter.
255///
256/// Note that a join equivalence treats null equal to null, while an `=` in a Filter does not.
257/// This function is mindful of this.
258fn unpack_equivalences(equivalences: &Vec<Vec<MirScalarExpr>>) -> Vec<MirScalarExpr> {
259    let mut result = Vec::new();
260    for mut class in equivalences.iter().cloned() {
261        // Let's put the simplest expression at the beginning of `class`, because all the
262        // expressions will involve `class[0]`. For example, without sorting, we can get stuff like
263        // `Filter (#0 = 5) AND (#0 = #1)`.
264        // With sorting, this comes out as
265        // `Filter (#0 = 5) AND (#1 = 5)`.
266        // TODO: In the long term, we might want to call the entire `EquivalenceClasses::minimize`.
267        class.sort_by(EquivalenceClasses::mir_scalar_expr_complexity);
268        for expr in class[1..].iter() {
269            result.push(MirScalarExpr::CallVariadic {
270                func: VariadicFunc::Or,
271                exprs: vec![
272                    MirScalarExpr::CallBinary {
273                        func: mz_expr::func::Eq.into(),
274                        expr1: Box::new(class[0].clone()),
275                        expr2: Box::new(expr.clone()),
276                    },
277                    MirScalarExpr::CallVariadic {
278                        func: VariadicFunc::And,
279                        exprs: vec![class[0].clone().call_is_null(), expr.clone().call_is_null()],
280                    },
281                ],
282            });
283        }
284    }
285    result
286}