Skip to main content

mz_transform/
normalize_lets.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Normalize the structure of `Let` and `LetRec` operators in expressions.
11//!
12//! Normalization happens in the context of "scopes", corresponding to
13//! 1. the expression's root and 2. each instance of a `LetRec` AST node.
14//!
15//! Within each scope,
16//! 1. Each expression is normalized to have all `Let` nodes at the root
17//! of the expression, in order of identifier.
18//! 2. Each expression assigns a contiguous block of identifiers.
19//!
20//! The transform may remove some `Let` and `Get` operators, and does not
21//! introduce any new operators.
22//!
23//! The module also publishes the function `renumber_bindings` which can
24//! be used to renumber bindings in an expression starting from a provided
25//! `IdGen`, which is used to prepare distinct expressions for inlining.
26
27use mz_expr::{MirRelationExpr, visit::Visit};
28use mz_ore::assert_none;
29use mz_ore::{id_gen::IdGen, stack::RecursionLimitError};
30use mz_repr::optimize::OptimizerFeatures;
31
32use crate::{TransformCtx, catch_unwind_optimize};
33
34pub use renumbering::renumber_bindings;
35
36/// Normalize `Let` and `LetRec` structure.
37pub fn normalize_lets(
38    expr: &mut MirRelationExpr,
39    features: &OptimizerFeatures,
40) -> Result<(), crate::TransformError> {
41    catch_unwind_optimize(|| NormalizeLets::new(false).action(expr, features))
42}
43
/// Normalizes `Let` and `LetRec` structure, and replaces certain `Get` operators with their
/// `Let` value (inlining).
#[derive(Debug)]
pub struct NormalizeLets {
    /// If `true`, inline MFPs around a Get.
    ///
    /// We want this value to be true for the NormalizeLets call that comes right
    /// before [crate::join_implementation::JoinImplementation] runs because
    /// - JoinImplementation cannot lift MFPs through a Let.
    /// - JoinImplementation can't extract FilterCharacteristics through a Let.
    ///
    /// Generally, though, we prefer to be more conservative in our inlining in
    /// order to be able to better detect CSEs.
    pub inline_mfp: bool,
}

impl NormalizeLets {
    /// Construct a new [`NormalizeLets`] instance with the given `inline_mfp`.
    pub fn new(inline_mfp: bool) -> NormalizeLets {
        NormalizeLets { inline_mfp }
    }
}
65
66impl crate::Transform for NormalizeLets {
67    fn name(&self) -> &'static str {
68        "NormalizeLets"
69    }
70
71    #[mz_ore::instrument(
72        target = "optimizer",
73        level = "debug",
74        fields(path.segment = "normalize_lets")
75    )]
76    fn actually_perform_transform(
77        &self,
78        relation: &mut MirRelationExpr,
79        ctx: &mut TransformCtx,
80    ) -> Result<(), crate::TransformError> {
81        let result = self.action(relation, ctx.features);
82        mz_repr::explain::trace_plan(&*relation);
83        result
84    }
85}
86
87impl NormalizeLets {
88    /// Normalize `Let` and `LetRec` bindings in `relation`.
89    ///
90    /// Mechanically, `action` first renumbers all bindings, erroring if any shadowing is encountered.
91    /// It then promotes all `Let` and `LetRec` expressions to the roots of their expressions, fusing
92    /// `Let` bindings into containing `LetRec` bindings, but leaving stacked `LetRec` bindings unfused to each
93    /// other (for reasons of correctness). It then considers potential inlining in each `LetRec` scope.
94    /// Lastly, it refreshes the types of each `Get` operator, erroring if any scalar types have changed
95    /// but updating nullability and keys.
96    ///
97    /// We then perform a final renumbering.
98    pub fn action(
99        &self,
100        relation: &mut MirRelationExpr,
101        features: &OptimizerFeatures,
102    ) -> Result<(), crate::TransformError> {
103        // Record whether the relation was initially recursive, to confirm that we do not introduce
104        // recursion to a non-recursive expression.
105        let was_recursive = relation.is_recursive();
106
107        // Renumber all bindings to ensure that identifier order matches binding order.
108        // In particular, as we use `BTreeMap` for binding order, we want to ensure that
109        // 1. Bindings within a `LetRec` are assigned increasing identifiers, and
110        // 2. Bindings across `LetRec`s are assigned identifiers in "visibility order", corresponding to an
111        // in-order traversal.
112        // TODO: More can and perhaps should be said about "visibility order" and how let promotion is correct.
113        renumbering::renumber_bindings(relation, &mut IdGen::default())?;
114
115        // Promote all `Let` and `LetRec` AST nodes to the roots.
116        // After this, all non-`LetRec` nodes contain no further `Let` or `LetRec` nodes,
117        // placing all `LetRec` nodes around the root, if not always in a single AST node.
118        let_motion::promote_let_rec(relation);
119        let_motion::assert_no_lets(relation);
120        let_motion::assert_letrec_major(relation);
121
122        // Inlining may violate letrec-major form.
123        inlining::inline_lets(relation, self.inline_mfp)?;
124
125        // Return to letrec-major form to refresh types.
126        let_motion::promote_let_rec(relation);
127        support::refresh_types(relation, features)?;
128
129        // Renumber bindings for good measure.
130        // Ideally we could skip when `action` is a no-op, but hard to thread that through at the moment.
131        renumbering::renumber_bindings(relation, &mut IdGen::default())?;
132
133        // A final bottom-up traversal to normalize the shape of nested LetRec blocks
134        relation.try_visit_mut_post(&mut |relation| -> Result<(), RecursionLimitError> {
135            // Move a non-recursive suffix of bindings from the end of the LetRec
136            // to the LetRec body.
137            // This is unsafe when applied to expressions which contain `ArrangeBy`,
138            // as if the extracted suffixes reference arrangements they will not be
139            // able to access those arrangements from outside the `LetRec` scope.
140            // It happens to work at the moment, so we don't touch it but should fix.
141            let bindings = let_motion::harvest_nonrec_suffix(relation)?;
142            if let MirRelationExpr::LetRec {
143                ids: _,
144                values: _,
145                limits: _,
146                body,
147            } = relation
148            {
149                for (id, value) in bindings.into_iter().rev() {
150                    **body = MirRelationExpr::Let {
151                        id,
152                        value: Box::new(value),
153                        body: Box::new(body.take_dangerous()),
154                    };
155                }
156            } else {
157                for (id, value) in bindings.into_iter().rev() {
158                    *relation = MirRelationExpr::Let {
159                        id,
160                        value: Box::new(value),
161                        body: Box::new(relation.take_dangerous()),
162                    };
163                }
164            }
165
166            // Extract `Let` prefixes from `LetRec`, to reveal their non-recursive nature.
167            // This assists with hoisting e.g. arrangements out of `LetRec` blocks, a thing
168            // we don't promise to do, but it can be helpful to do. This also exposes more
169            // AST nodes to non-`LetRec` analyses, which don't always have parity with `LetRec`.
170            let bindings = let_motion::harvest_non_recursive(relation);
171            for (id, (value, max_iter)) in bindings.into_iter().rev() {
172                assert_none!(max_iter);
173                *relation = MirRelationExpr::Let {
174                    id,
175                    value: Box::new(value),
176                    body: Box::new(relation.take_dangerous()),
177                };
178            }
179
180            Ok(())
181        })?;
182
183        if !was_recursive && relation.is_recursive() {
184            Err(crate::TransformError::Internal(
185                "NormalizeLets introduced LetRec to a LetRec-free expression".to_string(),
186            ))?;
187        }
188
189        Ok(())
190    }
191}
192
193// Support methods that are unlikely to be useful to other modules.
194mod support {
195
196    use std::collections::BTreeMap;
197
198    use itertools::Itertools;
199
200    use mz_expr::{Id, LetRecLimit, LocalId, MirRelationExpr};
201    use mz_repr::optimize::OptimizerFeatures;
202
203    pub(super) fn replace_bindings_from_map(
204        map: BTreeMap<LocalId, (MirRelationExpr, Option<LetRecLimit>)>,
205        ids: &mut Vec<LocalId>,
206        values: &mut Vec<MirRelationExpr>,
207        limits: &mut Vec<Option<LetRecLimit>>,
208    ) {
209        let (new_ids, new_values, new_limits) = map_to_3vecs(map);
210        *ids = new_ids;
211        *values = new_values;
212        *limits = new_limits;
213    }
214
215    pub(super) fn map_to_3vecs(
216        map: BTreeMap<LocalId, (MirRelationExpr, Option<LetRecLimit>)>,
217    ) -> (Vec<LocalId>, Vec<MirRelationExpr>, Vec<Option<LetRecLimit>>) {
218        let (new_ids, new_values_and_limits): (Vec<_>, Vec<_>) = map.into_iter().unzip();
219        let (new_values, new_limits) = new_values_and_limits.into_iter().unzip();
220        (new_ids, new_values, new_limits)
221    }
222
223    /// Logic mapped across each use of a `LocalId`.
224    pub(super) fn for_local_id<F>(expr: &MirRelationExpr, mut logic: F)
225    where
226        F: FnMut(LocalId),
227    {
228        expr.visit_pre(|expr| {
229            if let MirRelationExpr::Get {
230                id: Id::Local(i), ..
231            } = expr
232            {
233                logic(*i);
234            }
235        });
236    }
237
238    /// Populates `counts` with the number of uses of each local identifier in `expr`.
239    pub(super) fn count_local_id_uses(
240        expr: &MirRelationExpr,
241        counts: &mut std::collections::BTreeMap<LocalId, usize>,
242    ) {
243        for_local_id(expr, |i| *counts.entry(i).or_insert(0) += 1)
244    }
245
246    /// Visit `LetRec` stages and determine and update type information for `Get` nodes.
247    ///
248    /// This method errors if the scalar type information has changed (number of columns, or types).
249    /// It only refreshes the nullability and unique key information. As this information can regress,
250    /// we do not error if the type weakens, even though that may be something we want to look into.
251    ///
252    /// The method relies on the `analysis::{UniqueKeys, SqlRelationType}` analyses to improve its type
253    /// information for `LetRec` stages.
254    pub(super) fn refresh_types(
255        expr: &mut MirRelationExpr,
256        features: &OptimizerFeatures,
257    ) -> Result<(), crate::TransformError> {
258        // Assemble type information once for the whole expression.
259        use crate::analysis::{DerivedBuilder, SqlRelationType, UniqueKeys};
260        let mut builder = DerivedBuilder::new(features);
261        builder.require(SqlRelationType);
262        builder.require(UniqueKeys);
263        let derived = builder.visit(expr);
264        let derived_view = derived.as_view();
265
266        // Collect id -> type mappings.
267        let mut types = BTreeMap::new();
268        let mut todo = vec![(&*expr, derived_view)];
269        while let Some((expr, view)) = todo.pop() {
270            let ids = match expr {
271                MirRelationExpr::Let { id, .. } => std::slice::from_ref(id),
272                MirRelationExpr::LetRec { ids, .. } => ids,
273                _ => &[],
274            };
275            if !ids.is_empty() {
276                // The `skip(1)` skips the `body` child, and is followed by binding children.
277                for (id, view) in ids.iter().rev().zip_eq(view.children_rev().skip(1)) {
278                    let cols = view
279                        .value::<SqlRelationType>()
280                        .expect("SqlRelationType required")
281                        .clone()
282                        .expect("Expression not well typed");
283                    let keys = view
284                        .value::<UniqueKeys>()
285                        .expect("UniqueKeys required")
286                        .clone();
287                    types.insert(*id, mz_repr::SqlRelationType::new(cols).with_keys(keys));
288                }
289            }
290            todo.extend(expr.children().rev().zip_eq(view.children_rev()));
291        }
292
293        // Install the new types in each `Get`.
294        let mut todo = vec![&mut *expr];
295        while let Some(expr) = todo.pop() {
296            if let MirRelationExpr::Get {
297                id: Id::Local(i),
298                typ,
299                ..
300            } = expr
301            {
302                if let Some(new_type) = types.get(i) {
303                    // Assert that the column length has not changed.
304                    if !new_type.column_types.len() == typ.column_types.len() {
305                        Err(crate::TransformError::Internal(format!(
306                            "column lengths do not match: {:?} v {:?}",
307                            new_type.column_types, typ.column_types
308                        )))?;
309                    }
310                    // Assert that the column types have not changed.
311                    if !new_type
312                        .column_types
313                        .iter()
314                        .zip_eq(typ.column_types.iter())
315                        .all(|(t1, t2)| {
316                            t1.scalar_type
317                                .base_eq_or_repr_eq_for_assertion(&t2.scalar_type)
318                        })
319                    {
320                        Err(crate::TransformError::Internal(format!(
321                            "scalar types do not match: {:?} v {:?}",
322                            new_type.column_types, typ.column_types
323                        )))?;
324                    }
325
326                    typ.clone_from(new_type);
327                } else {
328                    panic!("Type not found for: {:?}", i);
329                }
330            }
331            todo.extend(expr.children_mut());
332        }
333        Ok(())
334    }
335}
336
337mod let_motion {
338
339    use std::collections::{BTreeMap, BTreeSet};
340
341    use itertools::Itertools;
342    use mz_expr::{LetRecLimit, LocalId, MirRelationExpr};
343    use mz_ore::stack::RecursionLimitError;
344
345    use crate::normalize_lets::support::replace_bindings_from_map;
346
347    /// Promotes all `Let` and `LetRec` nodes to the roots of their expressions.
348    ///
349    /// We cannot (without further reasoning) fuse stacked `LetRec` stages, and instead we just promote
350    /// `LetRec` to the roots of their expressions (e.g. as children of another `LetRec` stage).
351    pub(crate) fn promote_let_rec(expr: &mut MirRelationExpr) {
352        // First, promote all `LetRec` nodes above all other nodes.
353        let mut worklist = vec![&mut *expr];
354        while let Some(mut expr) = worklist.pop() {
355            hoist_bindings(expr);
356            while let MirRelationExpr::LetRec { values, body, .. } = expr {
357                worklist.extend(values.iter_mut().rev());
358                expr = body;
359            }
360        }
361
362        // Harvest any potential `Let` nodes, via a post-order traversal.
363        post_order_harvest_lets(expr);
364    }
365
366    /// A stand in for the types of bindings we might encounter.
367    ///
368    /// As we dissolve various `Let` and `LetRec` expressions, a `Binding` will carry
369    /// the relevant information as we hoist it to the root of the expression.
370    enum Binding {
371        // Binding resulting from a `Let` expression.
372        Let(LocalId, MirRelationExpr),
373        // Bindings resulting from a `LetRec` expression.
374        LetRec(Vec<(LocalId, MirRelationExpr, Option<LetRecLimit>)>),
375    }
376
377    /// Hoist all exposed bindings to the root of the expression.
378    ///
379    /// A binding is "exposed" if the path from the root does not cross a LetRec binding.
380    /// After the call, the expression should be a linear sequence of bindings, where each
381    /// `Let` binding is of a let-free expression. There may be `LetRec` expressions in the
382    /// sequence, and their bindings will have hoisted bindings to their root, but not out
383    /// of the binding.
384    fn hoist_bindings(expr: &mut MirRelationExpr) {
385        // Bindings we have extracted but not fully processed.
386        let mut worklist = Vec::new();
387        // Bindings we have extracted and then fully processed.
388        let mut finished = Vec::new();
389
390        extract_bindings(expr, &mut worklist);
391        while let Some(mut bind) = worklist.pop() {
392            match &mut bind {
393                Binding::Let(_id, value) => {
394                    extract_bindings(value, &mut worklist);
395                }
396                Binding::LetRec(_binds) => {
397                    // nothing to do here; we cannot hoist letrec bindings and refine
398                    // them in an outer loop.
399                }
400            }
401            finished.push(bind);
402        }
403
404        // The worklist is empty and finished should contain only LetRec bindings and Let
405        // bindings with let-free expressions bound. We need to re-assemble them now in
406        // the correct order. The identifiers are "sequential", so we should be able to
407        // sort by them, with some care.
408
409        // We only extract non-empty letrec bindings, so it is safe to peek at the first.
410        finished.sort_by_key(|b| match b {
411            Binding::Let(id, _) => *id,
412            Binding::LetRec(binds) => binds[0].0,
413        });
414
415        // To match historical behavior we fuse let bindings into adjacent letrec bindings.
416        // We could alternately make each a singleton letrec binding (just, non-recursive).
417        // We don't yet have a strong opinion on which is most helpful and least harmful.
418        // In the absence of any letrec bindings, we form one to house the let bindings.
419        let mut ids = Vec::new();
420        let mut values = Vec::new();
421        let mut limits = Vec::new();
422        let mut compact = Vec::new();
423        for bind in finished {
424            match bind {
425                Binding::Let(id, value) => {
426                    ids.push(id);
427                    values.push(value);
428                    limits.push(None);
429                }
430                Binding::LetRec(binds) => {
431                    for (id, value, limit) in binds {
432                        ids.push(id);
433                        values.push(value);
434                        limits.push(limit);
435                    }
436                    compact.push((ids, values, limits));
437                    ids = Vec::new();
438                    values = Vec::new();
439                    limits = Vec::new();
440                }
441            }
442        }
443
444        // Remaining bindings can either be fused to the prior letrec, or put in their own.
445        if let Some((last_ids, last_vals, last_lims)) = compact.last_mut() {
446            last_ids.extend(ids);
447            last_vals.extend(values);
448            last_lims.extend(limits);
449        } else if !ids.is_empty() {
450            compact.push((ids, values, limits));
451        }
452
453        while let Some((ids, values, limits)) = compact.pop() {
454            *expr = MirRelationExpr::LetRec {
455                ids,
456                values,
457                limits,
458                body: Box::new(expr.take_dangerous()),
459            };
460        }
461    }
462
463    /// Extracts exposed bindings into `bindings`.
464    ///
465    /// After this call `expr` will contain no let or letrec bindings, though the bindings
466    /// it introduces to `bindings` may themselves contain such bindings (and they should
467    /// be further processed if the goal is to maximally extract let bindings).
468    fn extract_bindings(expr: &mut MirRelationExpr, bindings: &mut Vec<Binding>) {
469        let mut todo = vec![expr];
470        while let Some(expr) = todo.pop() {
471            match expr {
472                MirRelationExpr::Let { id, value, body } => {
473                    bindings.push(Binding::Let(*id, value.take_dangerous()));
474                    *expr = body.take_dangerous();
475                    todo.push(expr);
476                }
477                MirRelationExpr::LetRec {
478                    ids,
479                    values,
480                    limits,
481                    body,
482                } => {
483                    use itertools::Itertools;
484                    let binds: Vec<_> = ids
485                        .drain(..)
486                        .zip_eq(values.drain(..))
487                        .zip_eq(limits.drain(..))
488                        .map(|((i, v), l)| (i, v, l))
489                        .collect();
490                    if !binds.is_empty() {
491                        bindings.push(Binding::LetRec(binds));
492                    }
493                    *expr = body.take_dangerous();
494                    todo.push(expr);
495                }
496                _ => {
497                    todo.extend(expr.children_mut());
498                }
499            }
500        }
501    }
502
503    /// Performs a post-order traversal of the `LetRec` nodes at the root of an expression.
504    ///
505    /// The traversal is only of the `LetRec` nodes, for which fear of stack exhaustion is nominal.
506    fn post_order_harvest_lets(expr: &mut MirRelationExpr) {
507        if let MirRelationExpr::LetRec {
508            ids,
509            values,
510            limits,
511            body,
512        } = expr
513        {
514            // Only recursively descend through `LetRec` stages.
515            for value in values.iter_mut() {
516                post_order_harvest_lets(value);
517            }
518
519            let mut bindings = BTreeMap::new();
520            for ((id, mut value), max_iter) in ids
521                .drain(..)
522                .zip_eq(values.drain(..))
523                .zip_eq(limits.drain(..))
524            {
525                bindings.extend(harvest_non_recursive(&mut value));
526                bindings.insert(id, (value, max_iter));
527            }
528            bindings.extend(harvest_non_recursive(body));
529            replace_bindings_from_map(bindings, ids, values, limits);
530        }
531    }
532
533    /// Harvest any safe-to-lift non-recursive bindings from a `LetRec`
534    /// expression.
535    ///
536    /// At the moment, we reason that a binding can be lifted without changing
537    /// the output if both:
538    /// 1. It references no other non-lifted binding bound in `expr`,
539    /// 2. It is referenced by no prior non-lifted binding in `expr`.
540    ///
541    /// The rationale is that (1) ensures that the binding's value does not
542    /// change across iterations, and that (2) ensures that all observations of
543    /// the binding are after it assumes its first value, rather than when it
544    /// could be empty.
545    pub(crate) fn harvest_non_recursive(
546        expr: &mut MirRelationExpr,
547    ) -> BTreeMap<LocalId, (MirRelationExpr, Option<LetRecLimit>)> {
548        if let MirRelationExpr::LetRec {
549            ids,
550            values,
551            limits,
552            body,
553        } = expr
554        {
555            // Bindings to lift.
556            let mut lifted = BTreeMap::<LocalId, (MirRelationExpr, Option<LetRecLimit>)>::new();
557            // Bindings to retain.
558            let mut retained = BTreeMap::<LocalId, (MirRelationExpr, Option<LetRecLimit>)>::new();
559
560            // All remaining LocalIds bound by the enclosing LetRec.
561            let mut id_set = ids.iter().cloned().collect::<BTreeSet<LocalId>>();
562            // All LocalIds referenced up to (including) the current binding.
563            let mut cannot = BTreeSet::<LocalId>::new();
564            // The reference count of the current bindings.
565            let mut refcnt = BTreeMap::<LocalId, usize>::new();
566
567            for ((id, value), max_iter) in ids
568                .drain(..)
569                .zip_eq(values.drain(..))
570                .zip_eq(limits.drain(..))
571            {
572                refcnt.clear();
573                super::support::count_local_id_uses(&value, &mut refcnt);
574
575                // LocalIds that have already been referenced cannot be lifted.
576                cannot.extend(refcnt.keys().cloned());
577
578                // - The first conjunct excludes bindings that have already been
579                //   referenced.
580                // - The second conjunct excludes bindings that reference a
581                //   LocalId that either defined later or is a known retained.
582                if !cannot.contains(&id) && !refcnt.keys().any(|i| id_set.contains(i)) {
583                    lifted.insert(id, (value, None)); // Non-recursive bindings don't need a limit
584                    id_set.remove(&id);
585                } else {
586                    retained.insert(id, (value, max_iter));
587                }
588            }
589
590            replace_bindings_from_map(retained, ids, values, limits);
591            if values.is_empty() {
592                *expr = body.take_dangerous();
593            }
594
595            lifted
596        } else {
597            BTreeMap::new()
598        }
599    }
600
601    /// Harvest any safe-to-lower non-recursive suffix of binding from a
602    /// `LetRec` expression.
603    pub(crate) fn harvest_nonrec_suffix(
604        expr: &mut MirRelationExpr,
605    ) -> Result<BTreeMap<LocalId, MirRelationExpr>, RecursionLimitError> {
606        if let MirRelationExpr::LetRec {
607            ids,
608            values,
609            limits,
610            body,
611        } = expr
612        {
613            // Bindings to lower.
614            let mut lowered = BTreeMap::<LocalId, MirRelationExpr>::new();
615
616            let rec_ids = MirRelationExpr::recursive_ids(ids, values);
617
618            while ids.last().map(|id| !rec_ids.contains(id)).unwrap_or(false) {
619                let id = ids.pop().expect("non-empty ids");
620                let value = values.pop().expect("non-empty values");
621                let _limit = limits.pop().expect("non-empty limits");
622
623                lowered.insert(id, value); // Non-recursive bindings don't need a limit
624            }
625
626            if values.is_empty() {
627                *expr = body.take_dangerous();
628            }
629
630            Ok(lowered)
631        } else {
632            Ok(BTreeMap::new())
633        }
634    }
635
636    pub(crate) fn assert_no_lets(expr: &MirRelationExpr) {
637        expr.visit_pre(|expr| {
638            assert!(!matches!(expr, MirRelationExpr::Let { .. }));
639        });
640    }
641
642    /// Asserts that `expr` in "LetRec-major" form.
643    ///
644    /// This means `expr` is either `LetRec`-free, or a `LetRec` whose values and body are `LetRec`-major.
645    pub(crate) fn assert_letrec_major(expr: &MirRelationExpr) {
646        let mut todo = vec![expr];
647        while let Some(expr) = todo.pop() {
648            match expr {
649                MirRelationExpr::LetRec {
650                    ids: _,
651                    values,
652                    limits: _,
653                    body,
654                } => {
655                    todo.extend(values.iter());
656                    todo.push(body);
657                }
658                _ => {
659                    expr.visit_pre(|expr| {
660                        assert!(!matches!(expr, MirRelationExpr::LetRec { .. }));
661                    });
662                }
663            }
664        }
665    }
666}
667
668mod inlining {
669
670    use std::collections::BTreeMap;
671
672    use itertools::Itertools;
673    use mz_expr::{Id, LetRecLimit, LocalId, MirRelationExpr};
674
675    use crate::normalize_lets::support::replace_bindings_from_map;
676
677    pub(super) fn inline_lets(
678        expr: &mut MirRelationExpr,
679        inline_mfp: bool,
680    ) -> Result<(), crate::TransformError> {
681        let mut worklist = vec![&mut *expr];
682        while let Some(expr) = worklist.pop() {
683            inline_lets_core(expr, inline_mfp)?;
684            // We descend only into `LetRec` nodes, because `promote_let_rec` ensured that all
685            // `LetRec` nodes are clustered near the root. This means that we can get to all the
686            // `LetRec` nodes by just descending into `LetRec` nodes, as there can't be any other
687            // nodes between them.
688            if let MirRelationExpr::LetRec {
689                ids: _,
690                values,
691                limits: _,
692                body,
693            } = expr
694            {
695                worklist.extend(values);
696                worklist.push(body);
697            }
698        }
699        Ok(())
700    }
701
702    /// Considers inlining actions to perform for a sequence of bindings and a
703    /// following body.
704    ///
705    /// A let binding may be inlined only in subsequent bindings or in the body;
706    /// other bindings should not "immediately" observe the binding, as that
707    /// would be a change to the semantics of `LetRec`. For example, it would
708    /// not be correct to replace `C` with `A` in the definition of `B` here:
709    /// ```ignore
710    /// let A = ...;
711    /// let B = A - C;
712    /// let C = A;
713    /// ```
714    /// The explanation is that `B` should always be the difference between the
715    /// current and previous `A`, and that the substitution of `C` would instead
716    /// make it always zero, changing its definition.
717    ///
718    /// Here a let binding is proposed for inlining if any of the following is true:
719    ///  1. It has a single reference across all bindings and the body.
720    ///  2. It is a "sufficiently simple" `Get`, determined in part by the
721    ///     `inline_mfp` argument.
722    ///
723    /// We don't need extra checks for `limits`, because
724    ///  - `limits` is only relevant when a binding is directly used through a back edge (because
725    ///    that is where the rendering puts the `limits` check);
726    ///  - when a binding is directly used through a back edge, it can't be inlined anyway.
727    ///  - Also note that if a `LetRec` completely disappears at the end of `inline_lets_core`, then
728    ///    there was no recursion in it.
729    ///
730    /// The case of `Constant` binding is handled here (as opposed to
731    /// `FoldConstants`) in a somewhat limited manner (see database-issues#5346). Although a
732    /// bit weird, constants should also not be inlined into prior bindings as
733    /// this does change the behavior from one where the collection is initially
734    /// empty to one where it is always the constant.
735    ///
736    /// Having inlined bindings, many of them may now be dead (with no
737    /// transitive references from `body`). These can now be removed. They may
738    /// not be exactly those bindings that were inlineable, as we may not always
739    /// be able to apply inlining due to ordering (we cannot inline a binding
740    /// into one that is not strictly later).
741    pub(super) fn inline_lets_core(
742        expr: &mut MirRelationExpr,
743        inline_mfp: bool,
744    ) -> Result<(), crate::TransformError> {
745        if let MirRelationExpr::LetRec {
746            ids,
747            values,
748            limits,
749            body,
750        } = expr
751        {
752            // Count the number of uses of each local id across all expressions.
753            let mut counts = BTreeMap::new();
754            for value in values.iter() {
755                super::support::count_local_id_uses(value, &mut counts);
756            }
757            super::support::count_local_id_uses(body, &mut counts);
758
759            // Each binding can reach one of three positions on its inlineability:
760            //  1. The binding is used once and is available to be directly taken.
761            //  2. The binding is simple enough that it can just be cloned.
762            //  3. The binding is not available for inlining.
763            let mut inline_offers = BTreeMap::new();
764
765            // Each binding may require the expiration of prior inlining offers.
766            // This occurs when an inlined body references the prior iterate of a binding,
767            // and inlining it would change the meaning to be the current iterate.
768            // Roughly, all inlining offers expire just after the binding of the least
769            // identifier they contain that is greater than the bound identifier itself.
770            let mut expire_offers = BTreeMap::new();
771            let mut expired_offers = Vec::new();
772
773            // For each binding, inline `Get`s and then determine if *it* should be inlined.
774            // It is important that we do the substitution in-order and before reasoning
775            // about the inlineability of each binding, to ensure that our conclusion about
776            // the inlineability of a binding stays put. Specifically,
777            //   1. by going in order no substitution will increase the `Get`-count of an
778            //      identifier beyond one, as all in values with strictly greater identifiers.
779            //   2. by performing the substitution before reasoning, the structure of the value
780            //      as it would be substituted is fixed.
781            for ((id, mut expr), max_iter) in ids
782                .drain(..)
783                .zip_eq(values.drain(..))
784                .zip_eq(limits.drain(..))
785            {
786                // Substitute any appropriate prior let bindings.
787                inline_lets_helper(&mut expr, &mut inline_offers)?;
788
789                // Determine the first `id'` at which any inlining offer must expire.
790                // An inlining offer expires because it references an `id'` that is not yet bound,
791                // indicating a reference to the *prior* iterate of that identifier. Inlining the
792                // expression once `id'` becomes bound would advance the reference to be the
793                // *current* iterate of the identifier.
794                MirRelationExpr::collect_expirations(id, &expr, &mut expire_offers);
795
796                // Gets for `id` only occur in later expressions, so this should still be correct.
797                let num_gets = counts.get(&id).map(|x| *x).unwrap_or(0);
798                // Counts of zero or one lead to substitution; otherwise certain simple structures
799                // are cloned in to `Get` operators, and all others emitted as `Let` bindings.
800                if num_gets == 0 {
801                } else if num_gets == 1 {
802                    inline_offers.insert(id, InlineOffer::Take(Some(expr), max_iter));
803                } else {
804                    let clone_binding = {
805                        let stripped_value = if inline_mfp {
806                            mz_expr::MapFilterProject::extract_non_errors_from_expr(&expr).1
807                        } else {
808                            &expr
809                        };
810                        match stripped_value {
811                            MirRelationExpr::Get { .. } | MirRelationExpr::Constant { .. } => true,
812                            _ => false,
813                        }
814                    };
815
816                    if clone_binding {
817                        inline_offers.insert(id, InlineOffer::Clone(expr, max_iter));
818                    } else {
819                        inline_offers.insert(id, InlineOffer::Unavailable(expr, max_iter));
820                    }
821                }
822
823                // We must now discard any offers that reference `id`, as it is no longer correct
824                // to inline such an offer as it would have access to this iteration's binding of
825                // `id` rather than the prior iteration's binding of `id`.
826                expired_offers.extend(MirRelationExpr::do_expirations(
827                    id,
828                    &mut expire_offers,
829                    &mut inline_offers,
830                ));
831            }
832            // Complete the inlining in `body`.
833            inline_lets_helper(body, &mut inline_offers)?;
834
835            // Re-introduce expired offers for the subsequent logic that expects to see them all.
836            for (id, offer) in expired_offers.into_iter() {
837                inline_offers.insert(id, offer);
838            }
839
840            // We may now be able to discard some of `inline_offer` based on the remaining pattern of `Get` expressions.
841            // Starting from `body` and working backwards, we can activate bindings that are still required because we
842            // observe `Get` expressions referencing them. Any bindings not so identified can be dropped (including any
843            // that may be part of a cycle not reachable from `body`).
844            let mut let_bindings = BTreeMap::new();
845            let mut todo = Vec::new();
846            super::support::for_local_id(body, |id| todo.push(id));
847            while let Some(id) = todo.pop() {
848                if let Some(offer) = inline_offers.remove(&id) {
849                    let (value, max_iter) = match offer {
850                        InlineOffer::Take(value, max_iter) => (
851                            value.ok_or_else(|| {
852                                crate::TransformError::Internal(
853                                    "Needed value already taken".to_string(),
854                                )
855                            })?,
856                            max_iter,
857                        ),
858                        InlineOffer::Clone(value, max_iter) => (value, max_iter),
859                        InlineOffer::Unavailable(value, max_iter) => (value, max_iter),
860                    };
861                    super::support::for_local_id(&value, |id| todo.push(id));
862                    let_bindings.insert(id, (value, max_iter));
863                }
864            }
865
866            // If bindings remain we update the `LetRec`, otherwise we remove it.
867            if !let_bindings.is_empty() {
868                replace_bindings_from_map(let_bindings, ids, values, limits);
869            } else {
870                *expr = body.take_dangerous();
871            }
872        }
873        Ok(())
874    }
875
    /// Possible states of let binding inlineability.
    ///
    /// Every variant also carries the binding's `Option<LetRecLimit>`. If the binding
    /// ends up being kept as a `LetRec` binding, the limit is carried along with its value.
    enum InlineOffer {
        /// There is a unique reference to this value and given the option it should take this expression.
        ///
        /// The `Option` becomes `None` once the value has been moved into its use site.
        Take(Option<MirRelationExpr>, Option<LetRecLimit>),
        /// Any reference to this value should clone this expression.
        Clone(MirRelationExpr, Option<LetRecLimit>),
        /// Any reference to this value should do no inlining of it.
        Unavailable(MirRelationExpr, Option<LetRecLimit>),
    }
885
886    /// Substitute `Get{id}` expressions for any proposed expressions.
887    ///
888    /// The proposed expressions can be proposed either to be taken or cloned.
889    fn inline_lets_helper(
890        expr: &mut MirRelationExpr,
891        inline_offer: &mut BTreeMap<LocalId, InlineOffer>,
892    ) -> Result<(), crate::TransformError> {
893        let mut worklist = vec![expr];
894        while let Some(expr) = worklist.pop() {
895            if let MirRelationExpr::Get {
896                id: Id::Local(id), ..
897            } = expr
898            {
899                if let Some(offer) = inline_offer.get_mut(id) {
900                    // It is important that we *not* continue to iterate
901                    // on the contents of `offer`, which has already been
902                    // maximally inlined. If we did, we could mis-inline
903                    // bindings into bodies that precede them, which would
904                    // change the semantics of the expression.
905                    match offer {
906                        InlineOffer::Take(value, _max_iter) => {
907                            *expr = value.take().ok_or_else(|| {
908                                crate::TransformError::Internal(format!(
909                                    "Value already taken for {:?}",
910                                    id
911                                ))
912                            })?;
913                        }
914                        InlineOffer::Clone(value, _max_iter) => {
915                            *expr = value.clone();
916                        }
917                        InlineOffer::Unavailable(_, _) => {
918                            // Do nothing.
919                        }
920                    }
921                } else {
922                    // Presumably a reference to an outer scope.
923                }
924            } else {
925                worklist.extend(expr.children_mut().rev());
926            }
927        }
928        Ok(())
929    }
930}
931
932mod renumbering {
933
934    use std::collections::BTreeMap;
935
936    use itertools::Itertools;
937    use mz_expr::{Id, LocalId, MirRelationExpr};
938    use mz_ore::id_gen::IdGen;
939
940    /// Re-assign an identifier to each `Let`.
941    ///
942    /// Under the assumption that `id_gen` produces identifiers in order, this process
943    /// maintains in-orderness of `LetRec` identifiers.
944    pub fn renumber_bindings(
945        relation: &mut MirRelationExpr,
946        id_gen: &mut IdGen,
947    ) -> Result<(), crate::TransformError> {
948        let mut renaming = BTreeMap::new();
949        determine(&*relation, &mut renaming, id_gen)?;
950        implement(relation, &renaming)?;
951        Ok(())
952    }
953
954    /// Performs an in-order traversal of the AST, assigning identifiers as it goes.
955    fn determine(
956        relation: &MirRelationExpr,
957        remap: &mut BTreeMap<LocalId, LocalId>,
958        id_gen: &mut IdGen,
959    ) -> Result<(), crate::TransformError> {
960        // The stack contains pending work as `Result<LocalId, &MirRelationExpr>`, where
961        // 1. 'Ok(id)` means the identifier `id` is ready for renumbering,
962        // 2. `Err(expr)` means that the expression `expr` needs to be further processed.
963        let mut stack: Vec<Result<LocalId, _>> = vec![Err(relation)];
964        while let Some(action) = stack.pop() {
965            match action {
966                Ok(id) => {
967                    if remap.contains_key(&id) {
968                        Err(crate::TransformError::Internal(format!(
969                            "Shadowing of let binding for {:?}",
970                            id
971                        )))?;
972                    } else {
973                        remap.insert(id, LocalId::new(id_gen.allocate_id()));
974                    }
975                }
976                Err(expr) => match expr {
977                    MirRelationExpr::Let { id, value, body } => {
978                        stack.push(Err(body));
979                        stack.push(Ok(*id));
980                        stack.push(Err(value));
981                    }
982                    MirRelationExpr::LetRec {
983                        ids,
984                        values,
985                        limits: _,
986                        body,
987                    } => {
988                        stack.push(Err(body));
989                        for (id, value) in ids.iter().rev().zip_eq(values.iter().rev()) {
990                            stack.push(Ok(*id));
991                            stack.push(Err(value));
992                        }
993                    }
994                    _ => {
995                        stack.extend(expr.children().rev().map(Err));
996                    }
997                },
998            }
999        }
1000        Ok(())
1001    }
1002
1003    fn implement(
1004        relation: &mut MirRelationExpr,
1005        remap: &BTreeMap<LocalId, LocalId>,
1006    ) -> Result<(), crate::TransformError> {
1007        let mut worklist = vec![relation];
1008        while let Some(expr) = worklist.pop() {
1009            match expr {
1010                MirRelationExpr::Let { id, .. } => {
1011                    *id = *remap
1012                        .get(id)
1013                        .ok_or(crate::TransformError::IdentifierMissing(*id))?;
1014                }
1015                MirRelationExpr::LetRec { ids, .. } => {
1016                    for id in ids.iter_mut() {
1017                        *id = *remap
1018                            .get(id)
1019                            .ok_or(crate::TransformError::IdentifierMissing(*id))?;
1020                    }
1021                }
1022                MirRelationExpr::Get {
1023                    id: Id::Local(id), ..
1024                } => {
1025                    *id = *remap
1026                        .get(id)
1027                        .ok_or(crate::TransformError::IdentifierMissing(*id))?;
1028                }
1029                _ => {
1030                    // Remapped identifiers not used in these patterns.
1031                }
1032            }
1033            // The order is not critical, but behave as a stack for clarity.
1034            worklist.extend(expr.children_mut().rev());
1035        }
1036        Ok(())
1037    }
1038}