// mz_transform/fold_constants.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Replace operators on constant collections with constant collections.
11
12use std::cmp::Ordering;
13use std::collections::{BTreeMap, BTreeSet};
14use std::convert::TryInto;
15use std::iter;
16
17use mz_expr::visit::Visit;
18use mz_expr::{
19    AggregateExpr, ColumnOrder, EvalError, MirRelationExpr, MirScalarExpr, TableFunc, UnaryFunc,
20};
21use mz_repr::{Datum, Diff, RelationType, Row, RowArena};
22
23use crate::{TransformCtx, TransformError, any};
24
/// Replace operators on constant collections with constant collections.
///
/// The per-node folding logic lives in [`FoldConstants::action`]; the transform
/// applies it to every node of an expression, visiting inputs before their parents.
#[derive(Debug)]
pub struct FoldConstants {
    /// Optional size bound past which folding gives up on a node.
    ///
    /// `None` means unbounded. Even then, there is no promise that every
    /// constant expression ends up as a `MirRelationExpr::Constant` variant.
    pub limit: Option<usize>,
}
35
36impl crate::Transform for FoldConstants {
37    fn name(&self) -> &'static str {
38        "FoldConstants"
39    }
40
41    #[mz_ore::instrument(
42        target = "optimizer",
43        level = "debug",
44        fields(path.segment = "fold_constants")
45    )]
46    fn actually_perform_transform(
47        &self,
48        relation: &mut MirRelationExpr,
49        _: &mut TransformCtx,
50    ) -> Result<(), TransformError> {
51        let mut type_stack = Vec::new();
52        let result = relation.try_visit_mut_post(&mut |e| -> Result<(), TransformError> {
53            let num_inputs = e.num_inputs();
54            let input_types = &type_stack[type_stack.len() - num_inputs..];
55            let mut relation_type = e.typ_with_input_types(input_types);
56            self.action(e, &mut relation_type)?;
57            type_stack.truncate(type_stack.len() - num_inputs);
58            type_stack.push(relation_type);
59            Ok(())
60        });
61        mz_repr::explain::trace_plan(&*relation);
62        result
63    }
64}
65
66impl FoldConstants {
67    /// Replace operators on constants collections with constant collections.
68    ///
69    /// This transform will cease optimization if it encounters constant collections
70    /// that are larger than `self.limit`, if that is set. It is not guaranteed that
71    /// a constant input within the limit will be reduced to a `Constant` variant.
72    pub fn action(
73        &self,
74        relation: &mut MirRelationExpr,
75        relation_type: &mut RelationType,
76    ) -> Result<(), TransformError> {
77        match relation {
78            MirRelationExpr::Constant { .. } => { /* handled after match */ }
79            MirRelationExpr::Get { .. } => {}
80            MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
81                // Constant propagation through bindings is currently handled by in NormalizeLets.
82                // Maybe we should move it / replicate it here (see database-issues#5346 for context)?
83            }
84            MirRelationExpr::Reduce {
85                input,
86                group_key,
87                aggregates,
88                monotonic: _,
89                expected_group_size: _,
90            } => {
91                // Guard against evaluating an expression that may contain
92                // unmaterializable functions.
93                if group_key.iter().any(|e| e.contains_unmaterializable())
94                    || aggregates
95                        .iter()
96                        .any(|a| a.expr.contains_unmaterializable())
97                {
98                    return Ok(());
99                }
100
101                if let Some((rows, ..)) = (**input).as_const() {
102                    let new_rows = match rows {
103                        Ok(rows) => {
104                            if let Some(rows) =
105                                Self::fold_reduce_constant(group_key, aggregates, rows, self.limit)
106                            {
107                                rows
108                            } else {
109                                return Ok(());
110                            }
111                        }
112                        Err(e) => Err(e.clone()),
113                    };
114                    *relation = MirRelationExpr::Constant {
115                        rows: new_rows,
116                        typ: relation_type.clone(),
117                    };
118                }
119            }
120            MirRelationExpr::TopK {
121                input,
122                group_key,
123                order_key,
124                limit,
125                offset,
126                ..
127            } => {
128                // Only fold constants when:
129                //
130                // 1. The `limit` value is not set, or
131                // 2. The `limit` value is set to a literal x such that x >= 0.
132                //
133                // We can improve this to arbitrary expressions, but it requires
134                // more typing.
135                if any![
136                    limit.is_none(),
137                    limit.as_ref().and_then(|l| l.as_literal_int64()) >= Some(0),
138                ] {
139                    let limit = limit
140                        .as_ref()
141                        .and_then(|l| l.as_literal_int64().map(Into::into));
142                    if let Some((rows, ..)) = (**input).as_const_mut() {
143                        if let Ok(rows) = rows {
144                            Self::fold_topk_constant(group_key, order_key, &limit, offset, rows);
145                        }
146                        *relation = input.take_dangerous();
147                    }
148                }
149            }
150            MirRelationExpr::Negate { input } => {
151                if let Some((rows, ..)) = (**input).as_const_mut() {
152                    if let Ok(rows) = rows {
153                        for (_row, diff) in rows {
154                            *diff = -*diff;
155                        }
156                    }
157                    *relation = input.take_dangerous();
158                }
159            }
160            MirRelationExpr::Threshold { input } => {
161                if let Some((rows, ..)) = (**input).as_const_mut() {
162                    if let Ok(rows) = rows {
163                        rows.retain(|(_, diff)| diff.is_positive());
164                    }
165                    *relation = input.take_dangerous();
166                }
167            }
168            MirRelationExpr::Map { input, scalars } => {
169                // Guard against evaluating expression that may contain
170                // unmaterializable functions.
171                if scalars.iter().any(|e| e.contains_unmaterializable()) {
172                    return Ok(());
173                }
174
175                if let Some((rows, ..)) = (**input).as_const() {
176                    // Do not evaluate calls if:
177                    // 1. The input consist of at least one row, and
178                    // 2. The scalars is a singleton mz_panic('forced panic') call.
179                    // Instead, indicate to the caller to panic.
180                    if rows.as_ref().map_or(0, |r| r.len()) > 0 && scalars.len() == 1 {
181                        if let MirScalarExpr::CallUnary {
182                            func: UnaryFunc::Panic(_),
183                            expr,
184                        } = &scalars[0]
185                        {
186                            if let Some("forced panic") = expr.as_literal_str() {
187                                let msg = "forced panic".to_string();
188                                return Err(TransformError::CallerShouldPanic(msg));
189                            }
190                        }
191                    }
192
193                    let new_rows = match rows {
194                        Ok(rows) => rows
195                            .iter()
196                            .cloned()
197                            .map(|(input_row, diff)| {
198                                // TODO: reduce allocations to zero.
199                                let mut unpacked = input_row.unpack();
200                                let temp_storage = RowArena::new();
201                                for scalar in scalars.iter() {
202                                    unpacked.push(scalar.eval(&unpacked, &temp_storage)?)
203                                }
204                                Ok::<_, EvalError>((Row::pack_slice(&unpacked), diff))
205                            })
206                            .collect::<Result<_, _>>(),
207                        Err(e) => Err(e.clone()),
208                    };
209                    *relation = MirRelationExpr::Constant {
210                        rows: new_rows,
211                        typ: relation_type.clone(),
212                    };
213                }
214            }
215            MirRelationExpr::FlatMap { input, func, exprs } => {
216                // Guard against evaluating expression that may contain unmaterializable functions.
217                if exprs.iter().any(|e| e.contains_unmaterializable()) {
218                    return Ok(());
219                }
220
221                if let Some((rows, ..)) = (**input).as_const() {
222                    let new_rows = match rows {
223                        Ok(rows) => Self::fold_flat_map_constant(func, exprs, rows, self.limit),
224                        Err(e) => Err(e.clone()),
225                    };
226                    match new_rows {
227                        Ok(None) => {}
228                        Ok(Some(rows)) => {
229                            *relation = MirRelationExpr::Constant {
230                                rows: Ok(rows),
231                                typ: relation_type.clone(),
232                            };
233                        }
234                        Err(err) => {
235                            *relation = MirRelationExpr::Constant {
236                                rows: Err(err),
237                                typ: relation_type.clone(),
238                            };
239                        }
240                    };
241                }
242            }
243            MirRelationExpr::Filter { input, predicates } => {
244                // Guard against evaluating expression that may contain
245                // unmaterializable function calls.
246                if predicates.iter().any(|e| e.contains_unmaterializable()) {
247                    return Ok(());
248                }
249
250                // If any predicate is false, reduce to the empty collection.
251                if predicates
252                    .iter()
253                    .any(|p| p.is_literal_false() || p.is_literal_null())
254                {
255                    relation.take_safely(Some(relation_type.clone()));
256                } else if let Some((rows, ..)) = (**input).as_const() {
257                    // Evaluate errors last, to reduce risk of spurious errors.
258                    predicates.sort_by_key(|p| p.is_literal_err());
259                    let new_rows = match rows {
260                        Ok(rows) => Self::fold_filter_constant(predicates, rows),
261                        Err(e) => Err(e.clone()),
262                    };
263                    *relation = MirRelationExpr::Constant {
264                        rows: new_rows,
265                        typ: relation_type.clone(),
266                    };
267                }
268            }
269            MirRelationExpr::Project { input, outputs } => {
270                if let Some((rows, ..)) = (**input).as_const() {
271                    let mut row_buf = Row::default();
272                    let new_rows = match rows {
273                        Ok(rows) => Ok(rows
274                            .iter()
275                            .map(|(input_row, diff)| {
276                                // TODO: reduce allocations to zero.
277                                let datums = input_row.unpack();
278                                row_buf.packer().extend(outputs.iter().map(|i| &datums[*i]));
279                                (row_buf.clone(), *diff)
280                            })
281                            .collect()),
282                        Err(e) => Err(e.clone()),
283                    };
284                    *relation = MirRelationExpr::Constant {
285                        rows: new_rows,
286                        typ: relation_type.clone(),
287                    };
288                }
289            }
290            MirRelationExpr::Join {
291                inputs,
292                equivalences,
293                ..
294            } => {
295                if inputs.iter().any(|e| e.is_empty()) {
296                    relation.take_safely(Some(relation_type.clone()));
297                } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) {
298                    *relation = MirRelationExpr::Constant {
299                        rows: Err(e.clone()),
300                        typ: relation_type.clone(),
301                    };
302                } else if inputs
303                    .iter()
304                    .all(|i| matches!(i.as_const(), Some((Ok(_), ..))))
305                {
306                    // Guard against evaluating expression that may contain unmaterializable functions.
307                    if equivalences
308                        .iter()
309                        .any(|equiv| equiv.iter().any(|e| e.contains_unmaterializable()))
310                    {
311                        return Ok(());
312                    }
313
314                    // We can fold all constant inputs together, but must apply the constraints to restrict them.
315                    // We start with a single 0-ary row.
316                    let mut old_rows = vec![(Row::pack::<_, Datum>(None), Diff::ONE)];
317                    let mut row_buf = Row::default();
318                    for input in inputs.iter() {
319                        if let Some((Ok(rows), ..)) = input.as_const() {
320                            if let Some(limit) = self.limit {
321                                if old_rows.len() * rows.len() > limit {
322                                    // Bail out if we have produced too many rows.
323                                    // TODO: progressively apply equivalences to narrow this count
324                                    // as we go, rather than at the end.
325                                    return Ok(());
326                                }
327                            }
328                            let mut next_rows = Vec::new();
329                            for (old_row, old_count) in old_rows {
330                                for (new_row, new_count) in rows.iter() {
331                                    let mut packer = row_buf.packer();
332                                    packer.extend_by_row(&old_row);
333                                    packer.extend_by_row(new_row);
334                                    next_rows.push((row_buf.clone(), old_count * *new_count));
335                                }
336                            }
337                            old_rows = next_rows;
338                        }
339                    }
340
341                    // Now throw away anything that doesn't satisfy the requisite constraints.
342                    let mut datum_vec = mz_repr::DatumVec::new();
343                    old_rows.retain(|(row, _count)| {
344                        let datums = datum_vec.borrow_with(row);
345                        let temp_storage = RowArena::new();
346                        equivalences.iter().all(|equivalence| {
347                            let mut values =
348                                equivalence.iter().map(|e| e.eval(&datums, &temp_storage));
349                            if let Some(value) = values.next() {
350                                values.all(|v| v == value)
351                            } else {
352                                true
353                            }
354                        })
355                    });
356
357                    *relation = MirRelationExpr::Constant {
358                        rows: Ok(old_rows),
359                        typ: relation_type.clone(),
360                    };
361                }
362                // TODO: General constant folding for all constant inputs.
363            }
364            MirRelationExpr::Union { base, inputs } => {
365                if let Some(e) = iter::once(&mut **base)
366                    .chain(&mut *inputs)
367                    .find_map(|i| i.as_const_err())
368                {
369                    *relation = MirRelationExpr::Constant {
370                        rows: Err(e.clone()),
371                        typ: relation_type.clone(),
372                    };
373                } else {
374                    let mut rows = vec![];
375                    let mut new_inputs = vec![];
376
377                    for input in iter::once(&mut **base).chain(&mut *inputs) {
378                        if let Some((Ok(rs), ..)) = input.as_const() {
379                            rows.extend(rs.clone());
380                        } else {
381                            new_inputs.push(input.clone())
382                        }
383                    }
384                    if !rows.is_empty() {
385                        new_inputs.push(MirRelationExpr::Constant {
386                            rows: Ok(rows),
387                            typ: relation_type.clone(),
388                        });
389                    }
390
391                    *relation = MirRelationExpr::union_many(new_inputs, relation_type.clone());
392                }
393            }
394            MirRelationExpr::ArrangeBy { .. } => {
395                // Don't fold ArrangeBys, because that could result in unarranged Delta join inputs.
396                // See also the comment on `MirRelationExpr::Constant`.
397            }
398        }
399
400        // This transformation maintains the invariant that all constant nodes
401        // will be consolidated. We have to make a separate check for constant
402        // nodes here, since the match arm above might install new constant
403        // nodes.
404        if let Some((Ok(rows), typ)) = relation.as_const_mut() {
405            // Reduce down to canonical representation.
406            differential_dataflow::consolidation::consolidate(rows);
407
408            // Re-establish nullability of each column.
409            for col_type in typ.column_types.iter_mut() {
410                col_type.nullable = false;
411            }
412            for (row, _) in rows.iter_mut() {
413                for (index, datum) in row.iter().enumerate() {
414                    if datum.is_null() {
415                        typ.column_types[index].nullable = true;
416                    }
417                }
418            }
419            *relation_type = typ.clone();
420        }
421
422        Ok(())
423    }
424
425    // TODO(benesch): remove this once this function no longer makes use of
426    // potentially dangerous `as` conversions.
427    #[allow(clippy::as_conversions)]
428    fn fold_reduce_constant(
429        group_key: &[MirScalarExpr],
430        aggregates: &[AggregateExpr],
431        rows: &[(Row, Diff)],
432        limit: Option<usize>,
433    ) -> Option<Result<Vec<(Row, Diff)>, EvalError>> {
434        // Build a map from `group_key` to `Vec<Vec<an, ..., a1>>)`,
435        // where `an` is the input to the nth aggregate function in
436        // `aggregates`.
437        let mut groups = BTreeMap::new();
438        let temp_storage2 = RowArena::new();
439        let mut row_buf = Row::default();
440        let mut limit_remaining =
441            limit.map_or(Diff::MAX, |limit| Diff::try_from(limit).expect("must fit"));
442        for (row, diff) in rows {
443            // We currently maintain the invariant that any negative
444            // multiplicities will be consolidated away before they
445            // arrive at a reduce.
446
447            if *diff <= Diff::ZERO {
448                return Some(Err(EvalError::InvalidParameterValue(
449                    "constant folding encountered reduce on collection with non-positive multiplicities".into()
450                )));
451            }
452
453            if limit_remaining < *diff {
454                return None;
455            }
456            limit_remaining -= diff;
457
458            let datums = row.unpack();
459            let temp_storage = RowArena::new();
460            let key = match group_key
461                .iter()
462                .map(|e| e.eval(&datums, &temp_storage2))
463                .collect::<Result<Vec<_>, _>>()
464            {
465                Ok(key) => key,
466                Err(e) => return Some(Err(e)),
467            };
468            let val = match aggregates
469                .iter()
470                .map(|agg| {
471                    row_buf
472                        .packer()
473                        .extend([agg.expr.eval(&datums, &temp_storage)?]);
474                    Ok::<_, EvalError>(row_buf.clone())
475                })
476                .collect::<Result<Vec<_>, _>>()
477            {
478                Ok(val) => val,
479                Err(e) => return Some(Err(e)),
480            };
481            let entry = groups.entry(key).or_insert_with(Vec::new);
482            for _ in 0..diff.into_inner() {
483                entry.push(val.clone());
484            }
485        }
486
487        // For each group, apply the aggregate function to the rows
488        // in the group. The output is
489        // `Vec<Vec<k1, ..., kn, r1, ..., rn>>`
490        // where kn is the nth column of the key and rn is the
491        // result of the nth aggregate function for that group.
492        let new_rows = groups
493            .into_iter()
494            .map({
495                let mut row_buf = Row::default();
496                move |(key, vals)| {
497                    let temp_storage = RowArena::new();
498                    row_buf.packer().extend(key.into_iter().chain(
499                        aggregates.iter().enumerate().map(|(i, agg)| {
500                            if agg.distinct {
501                                agg.func.eval(
502                                    vals.iter()
503                                        .map(|val| val[i].unpack_first())
504                                        .collect::<BTreeSet<_>>(),
505                                    &temp_storage,
506                                )
507                            } else {
508                                agg.func.eval(
509                                    vals.iter().map(|val| val[i].unpack_first()),
510                                    &temp_storage,
511                                )
512                            }
513                        }),
514                    ));
515                    (row_buf.clone(), Diff::ONE)
516                }
517            })
518            .collect();
519        Some(Ok(new_rows))
520    }
521
522    fn fold_topk_constant<'a>(
523        group_key: &[usize],
524        order_key: &[ColumnOrder],
525        limit: &Option<Diff>,
526        offset: &usize,
527        rows: &'a mut [(Row, Diff)],
528    ) {
529        // helper functions for comparing elements by order_key and group_key
530        let mut lhs_datum_vec = mz_repr::DatumVec::new();
531        let mut rhs_datum_vec = mz_repr::DatumVec::new();
532        let mut cmp_order_key = |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
533            let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
534            let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
535            mz_expr::compare_columns(order_key, lhs_datums, rhs_datums, || lhs.cmp(rhs))
536        };
537        let mut cmp_group_key = {
538            let group_key = group_key
539                .iter()
540                .map(|column| ColumnOrder {
541                    column: *column,
542                    // desc and nulls_last don't matter: the sorting by cmp_group_key is just to
543                    // make the elements of each group appear next to each other, but the order of
544                    // groups doesn't matter.
545                    desc: false,
546                    nulls_last: false,
547                })
548                .collect::<Vec<ColumnOrder>>();
549            let mut lhs_datum_vec = mz_repr::DatumVec::new();
550            let mut rhs_datum_vec = mz_repr::DatumVec::new();
551            move |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
552                let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
553                let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
554                mz_expr::compare_columns(&group_key, lhs_datums, rhs_datums, || Ordering::Equal)
555            }
556        };
557
558        // compute Ordering based on the sort_key, otherwise consider all rows equal
559        rows.sort_by(&mut cmp_order_key);
560
561        // sort by the grouping key if not empty, keeping order_key as a secondary sort
562        if !group_key.is_empty() {
563            rows.sort_by(&mut cmp_group_key);
564        };
565
566        let mut same_group_key =
567            |lhs: &(Row, Diff), rhs: &(Row, Diff)| cmp_group_key(lhs, rhs) == Ordering::Equal;
568
569        let mut cursor = 0;
570        while cursor < rows.len() {
571            // first, reset the remaining limit and offset for the current group
572            let mut offset_rem: Diff = offset.clone().try_into().unwrap();
573            let mut limit_rem: Option<Diff> = limit.clone();
574
575            let mut finger = cursor;
576            while finger < rows.len() && same_group_key(&rows[cursor], &rows[finger]) {
577                if rows[finger].1.is_negative() {
578                    // ignore elements with negative diff
579                    rows[finger].1 = Diff::ZERO;
580                } else {
581                    // determine how many of the leading rows to ignore,
582                    // then decrement the diff and remaining offset by that number
583                    let rows_to_ignore = std::cmp::min(offset_rem, rows[finger].1);
584                    rows[finger].1 -= rows_to_ignore;
585                    offset_rem -= rows_to_ignore;
586                    // determine how many of the remaining rows to retain,
587                    // then update the diff and decrement the remaining limit by that number
588                    if let Some(limit_rem) = &mut limit_rem {
589                        let rows_to_retain = std::cmp::min(*limit_rem, rows[finger].1);
590                        rows[finger].1 = rows_to_retain;
591                        *limit_rem -= rows_to_retain;
592                    }
593                }
594                finger += 1;
595            }
596            cursor = finger;
597        }
598    }
599
600    fn fold_flat_map_constant(
601        func: &TableFunc,
602        exprs: &[MirScalarExpr],
603        rows: &[(Row, Diff)],
604        limit: Option<usize>,
605    ) -> Result<Option<Vec<(Row, Diff)>>, EvalError> {
606        // We cannot exceed `usize::MAX` in any array, so this is a fine upper bound.
607        let limit = limit.unwrap_or(usize::MAX);
608        let mut new_rows = Vec::new();
609        let mut row_buf = Row::default();
610        let mut datum_vec = mz_repr::DatumVec::new();
611        for (input_row, diff) in rows {
612            let datums = datum_vec.borrow_with(input_row);
613            let temp_storage = RowArena::new();
614            let datums = exprs
615                .iter()
616                .map(|expr| expr.eval(&datums, &temp_storage))
617                .collect::<Result<Vec<_>, _>>()?;
618            let mut output_rows = func.eval(&datums, &temp_storage)?.fuse();
619            for (output_row, diff2) in (&mut output_rows).take(limit - new_rows.len()) {
620                let mut packer = row_buf.packer();
621                packer.extend_by_row(input_row);
622                packer.extend_by_row(&output_row);
623                new_rows.push((row_buf.clone(), diff2 * *diff))
624            }
625            // If we still have records to enumerate, but dropped out of the iteration,
626            // it means we have exhausted `limit` and should stop.
627            if output_rows.next() != None {
628                return Ok(None);
629            }
630        }
631        Ok(Some(new_rows))
632    }
633
634    fn fold_filter_constant(
635        predicates: &[MirScalarExpr],
636        rows: &[(Row, Diff)],
637    ) -> Result<Vec<(Row, Diff)>, EvalError> {
638        let mut new_rows = Vec::new();
639        let mut datum_vec = mz_repr::DatumVec::new();
640        'outer: for (row, diff) in rows {
641            let datums = datum_vec.borrow_with(row);
642            let temp_storage = RowArena::new();
643            for p in &*predicates {
644                if p.eval(&datums, &temp_storage)? != Datum::True {
645                    continue 'outer;
646                }
647            }
648            new_rows.push((row.clone(), *diff))
649        }
650        Ok(new_rows)
651    }
652}