Skip to main content

mz_transform/
fold_constants.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Replace operators on constant collections with constant collections.
11
12use std::cmp::Ordering;
13use std::collections::{BTreeMap, BTreeSet};
14use std::convert::TryInto;
15use std::iter;
16
17use mz_expr::visit::Visit;
18use mz_expr::{
19    AggregateExpr, ColumnOrder, EvalError, MirRelationExpr, MirScalarExpr, RowComparator,
20    TableFunc, UnaryFunc,
21};
22use mz_repr::{Datum, Diff, ReprRelationType, Row, RowArena};
23
24use crate::{TransformCtx, TransformError, any};
25
/// A transform that evaluates operators whose inputs are constant collections
/// and substitutes the resulting constant collection for the operator.
#[derive(Debug)]
pub struct FoldConstants {
    /// Optional upper bound on the size of the collections this transform is
    /// willing to work with; once exceeded, folding may simply stop.
    ///
    /// `None` means "no bound". Even then there is currently no guarantee
    /// that every constant expression ends up as a
    /// `MirRelationExpr::Constant` variant.
    pub limit: Option<usize>,
}
36
37impl crate::Transform for FoldConstants {
38    fn name(&self) -> &'static str {
39        "FoldConstants"
40    }
41
42    #[mz_ore::instrument(
43        target = "optimizer",
44        level = "debug",
45        fields(path.segment = "fold_constants")
46    )]
47    fn actually_perform_transform(
48        &self,
49        relation: &mut MirRelationExpr,
50        _: &mut TransformCtx,
51    ) -> Result<(), TransformError> {
52        let mut type_stack = Vec::new();
53        let result = relation.try_visit_mut_post(&mut |e| -> Result<(), TransformError> {
54            let num_inputs = e.num_inputs();
55            let input_types = &type_stack[type_stack.len() - num_inputs..];
56            let mut relation_type = e.typ_with_input_types(input_types);
57            self.action(e, &mut relation_type)?;
58            type_stack.truncate(type_stack.len() - num_inputs);
59            type_stack.push(relation_type);
60            Ok(())
61        });
62        mz_repr::explain::trace_plan(&*relation);
63        result
64    }
65}
66
67impl FoldConstants {
68    /// Replace operators on constants collections with constant collections.
69    ///
70    /// This transform will cease optimization if it encounters constant collections
71    /// that are larger than `self.limit`, if that is set. It is not guaranteed that
72    /// a constant input within the limit will be reduced to a `Constant` variant.
73    pub fn action(
74        &self,
75        relation: &mut MirRelationExpr,
76        relation_type: &mut ReprRelationType,
77    ) -> Result<(), TransformError> {
78        match relation {
79            MirRelationExpr::Constant { .. } => { /* handled after match */ }
80            MirRelationExpr::Get { .. } => {}
81            MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
82                // Constant propagation through bindings is currently handled by in NormalizeLets.
83                // Maybe we should move it / replicate it here (see database-issues#5346 for context)?
84            }
85            MirRelationExpr::Reduce {
86                input,
87                group_key,
88                aggregates,
89                monotonic: _,
90                expected_group_size: _,
91            } => {
92                // Guard against evaluating an expression that may contain
93                // unmaterializable functions.
94                if group_key.iter().any(|e| e.contains_unmaterializable())
95                    || aggregates
96                        .iter()
97                        .any(|a| a.expr.contains_unmaterializable())
98                {
99                    return Ok(());
100                }
101
102                if let Some((rows, ..)) = (**input).as_const() {
103                    let new_rows = match rows {
104                        Ok(rows) => {
105                            if let Some(rows) =
106                                Self::fold_reduce_constant(group_key, aggregates, rows, self.limit)
107                            {
108                                rows
109                            } else {
110                                return Ok(());
111                            }
112                        }
113                        Err(e) => Err(e.clone()),
114                    };
115                    *relation = MirRelationExpr::Constant {
116                        rows: new_rows,
117                        typ: relation_type.clone(),
118                    };
119                }
120            }
121            MirRelationExpr::TopK {
122                input,
123                group_key,
124                order_key,
125                limit,
126                offset,
127                ..
128            } => {
129                // Only fold constants when:
130                //
131                // 1. The `limit` value is not set, or
132                // 2. The `limit` value is set to a literal x such that x >= 0.
133                //
134                // We can improve this to arbitrary expressions, but it requires
135                // more typing.
136                if any![
137                    limit.is_none(),
138                    limit.as_ref().and_then(|l| l.as_literal_int64()) >= Some(0),
139                ] {
140                    let limit = limit
141                        .as_ref()
142                        .and_then(|l| l.as_literal_int64().map(Into::into));
143                    if let Some((rows, ..)) = (**input).as_const_mut() {
144                        if let Ok(rows) = rows {
145                            Self::fold_topk_constant(group_key, order_key, &limit, offset, rows);
146                        }
147                        *relation = input.take_dangerous();
148                    }
149                }
150            }
151            MirRelationExpr::Negate { input } => {
152                if let Some((rows, ..)) = (**input).as_const_mut() {
153                    if let Ok(rows) = rows {
154                        for (_row, diff) in rows {
155                            *diff = -*diff;
156                        }
157                    }
158                    *relation = input.take_dangerous();
159                }
160            }
161            MirRelationExpr::Threshold { input } => {
162                if let Some((rows, ..)) = (**input).as_const_mut() {
163                    if let Ok(rows) = rows {
164                        rows.retain(|(_, diff)| diff.is_positive());
165                    }
166                    *relation = input.take_dangerous();
167                }
168            }
169            MirRelationExpr::Map { input, scalars } => {
170                // Guard against evaluating expression that may contain
171                // unmaterializable functions.
172                if scalars.iter().any(|e| e.contains_unmaterializable()) {
173                    return Ok(());
174                }
175
176                if let Some((rows, ..)) = (**input).as_const() {
177                    // Do not evaluate calls if:
178                    // 1. The input consist of at least one row, and
179                    // 2. The scalars is a singleton mz_panic('forced panic') call.
180                    // Instead, indicate to the caller to panic.
181                    if rows.as_ref().map_or(0, |r| r.len()) > 0 && scalars.len() == 1 {
182                        if let MirScalarExpr::CallUnary {
183                            func: UnaryFunc::Panic(_),
184                            expr,
185                        } = &scalars[0]
186                        {
187                            if let Some("forced panic") = expr.as_literal_str() {
188                                let msg = "forced panic".to_string();
189                                return Err(TransformError::CallerShouldPanic(msg));
190                            }
191                        }
192                    }
193
194                    let new_rows = match rows {
195                        Ok(rows) => rows
196                            .iter()
197                            .map(|(input_row, diff)| {
198                                // TODO: reduce allocations to zero.
199                                let mut unpacked = input_row.unpack();
200                                let temp_storage = RowArena::new();
201                                for scalar in scalars.iter() {
202                                    unpacked.push(scalar.eval(&unpacked, &temp_storage)?)
203                                }
204                                Ok::<_, EvalError>((Row::pack_slice(&unpacked), *diff))
205                            })
206                            .collect::<Result<_, _>>(),
207                        Err(e) => Err(e.clone()),
208                    };
209                    *relation = MirRelationExpr::Constant {
210                        rows: new_rows,
211                        typ: relation_type.clone(),
212                    };
213                }
214            }
215            MirRelationExpr::FlatMap { input, func, exprs } => {
216                // Guard against evaluating expression that may contain unmaterializable functions.
217                if exprs.iter().any(|e| e.contains_unmaterializable()) {
218                    return Ok(());
219                }
220
221                if let Some((rows, ..)) = (**input).as_const() {
222                    let new_rows = match rows {
223                        Ok(rows) => Self::fold_flat_map_constant(func, exprs, rows, self.limit),
224                        Err(e) => Err(e.clone()),
225                    };
226                    match new_rows {
227                        Ok(None) => {}
228                        Ok(Some(rows)) => {
229                            *relation = MirRelationExpr::Constant {
230                                rows: Ok(rows),
231                                typ: relation_type.clone(),
232                            };
233                        }
234                        Err(err) => {
235                            *relation = MirRelationExpr::Constant {
236                                rows: Err(err),
237                                typ: relation_type.clone(),
238                            };
239                        }
240                    };
241                }
242            }
243            MirRelationExpr::Filter { input, predicates } => {
244                // Guard against evaluating expression that may contain
245                // unmaterializable function calls.
246                if predicates.iter().any(|e| e.contains_unmaterializable()) {
247                    return Ok(());
248                }
249
250                // If any predicate is false, reduce to the empty collection.
251                if predicates
252                    .iter()
253                    .any(|p| p.is_literal_false() || p.is_literal_null())
254                {
255                    relation.take_safely(Some(relation_type.clone()));
256                } else if let Some((rows, ..)) = (**input).as_const() {
257                    // Evaluate errors last, to reduce risk of spurious errors.
258                    predicates.sort_by_key(|p| p.is_literal_err());
259                    let new_rows = match rows {
260                        Ok(rows) => Self::fold_filter_constant(predicates, rows),
261                        Err(e) => Err(e.clone()),
262                    };
263                    *relation = MirRelationExpr::Constant {
264                        rows: new_rows,
265                        typ: relation_type.clone(),
266                    };
267                }
268            }
269            MirRelationExpr::Project { input, outputs } => {
270                if let Some((rows, ..)) = (**input).as_const() {
271                    let mut row_buf = Row::default();
272                    let new_rows = match rows {
273                        Ok(rows) => Ok(rows
274                            .iter()
275                            .map(|(input_row, diff)| {
276                                // TODO: reduce allocations to zero.
277                                let datums = input_row.unpack();
278                                row_buf.packer().extend(outputs.iter().map(|i| &datums[*i]));
279                                (row_buf.clone(), *diff)
280                            })
281                            .collect()),
282                        Err(e) => Err(e.clone()),
283                    };
284                    *relation = MirRelationExpr::Constant {
285                        rows: new_rows,
286                        typ: relation_type.clone(),
287                    };
288                }
289            }
290            MirRelationExpr::Join {
291                inputs,
292                equivalences,
293                ..
294            } => {
295                if inputs.iter().any(|e| e.is_empty()) {
296                    relation.take_safely(Some(relation_type.clone()));
297                } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) {
298                    *relation = MirRelationExpr::Constant {
299                        rows: Err(e.clone()),
300                        typ: relation_type.clone(),
301                    };
302                } else if inputs
303                    .iter()
304                    .all(|i| matches!(i.as_const(), Some((Ok(_), ..))))
305                {
306                    // Guard against evaluating expression that may contain unmaterializable functions.
307                    if equivalences
308                        .iter()
309                        .any(|equiv| equiv.iter().any(|e| e.contains_unmaterializable()))
310                    {
311                        return Ok(());
312                    }
313
314                    // We can fold all constant inputs together, but must apply the constraints to restrict them.
315                    // We start with a single 0-ary row.
316                    let mut old_rows = vec![(Row::pack::<_, Datum>(None), Diff::ONE)];
317                    let mut row_buf = Row::default();
318                    for input in inputs.iter() {
319                        if let Some((Ok(rows), ..)) = input.as_const() {
320                            if let Some(limit) = self.limit {
321                                if old_rows.len() * rows.len() > limit {
322                                    // Bail out if we have produced too many rows.
323                                    // TODO: progressively apply equivalences to narrow this count
324                                    // as we go, rather than at the end.
325                                    return Ok(());
326                                }
327                            }
328                            let mut next_rows = Vec::new();
329                            for (old_row, old_count) in old_rows {
330                                for (new_row, new_count) in rows.iter() {
331                                    let mut packer = row_buf.packer();
332                                    packer.extend_by_row(&old_row);
333                                    packer.extend_by_row(new_row);
334                                    next_rows.push((row_buf.clone(), old_count * *new_count));
335                                }
336                            }
337                            old_rows = next_rows;
338                        }
339                    }
340
341                    // Now throw away anything that doesn't satisfy the requisite constraints.
342                    let mut datum_vec = mz_repr::DatumVec::new();
343                    old_rows.retain(|(row, _count)| {
344                        let datums = datum_vec.borrow_with(row);
345                        let temp_storage = RowArena::new();
346                        equivalences.iter().all(|equivalence| {
347                            let mut values =
348                                equivalence.iter().map(|e| e.eval(&datums, &temp_storage));
349                            if let Some(value) = values.next() {
350                                values.all(|v| v == value)
351                            } else {
352                                true
353                            }
354                        })
355                    });
356
357                    *relation = MirRelationExpr::Constant {
358                        rows: Ok(old_rows),
359                        typ: relation_type.clone(),
360                    };
361                }
362                // TODO: General constant folding for all constant inputs.
363            }
364            MirRelationExpr::Union { base, inputs } => {
365                if let Some(e) = iter::once(&mut **base)
366                    .chain(&mut *inputs)
367                    .find_map(|i| i.as_const_err())
368                {
369                    *relation = MirRelationExpr::Constant {
370                        rows: Err(e.clone()),
371                        typ: relation_type.clone(),
372                    };
373                } else {
374                    let mut rows = vec![];
375                    let mut new_inputs = vec![];
376
377                    for input in iter::once(&mut **base).chain(&mut *inputs) {
378                        if let Some((Ok(rs), ..)) = input.as_const() {
379                            rows.extend(rs.clone());
380                        } else {
381                            new_inputs.push(input.clone())
382                        }
383                    }
384                    if !rows.is_empty() {
385                        new_inputs.push(MirRelationExpr::Constant {
386                            rows: Ok(rows),
387                            typ: relation_type.clone(),
388                        });
389                    }
390
391                    *relation = MirRelationExpr::union_many(new_inputs, relation_type.clone());
392                }
393            }
394            MirRelationExpr::ArrangeBy { .. } => {
395                // Don't fold ArrangeBys, because that could result in unarranged Delta join inputs.
396                // See also the comment on `MirRelationExpr::Constant`.
397            }
398        }
399
400        // This transformation maintains the invariant that all constant nodes
401        // will be consolidated. We have to make a separate check for constant
402        // nodes here, since the match arm above might install new constant
403        // nodes.
404        if let Some((Ok(rows), typ)) = relation.as_const_mut() {
405            // Reduce down to canonical representation.
406            differential_dataflow::consolidation::consolidate(rows);
407
408            // Re-establish nullability of each column.
409            for col_type in typ.column_types.iter_mut() {
410                col_type.nullable = false;
411            }
412            for (row, _) in rows.iter_mut() {
413                for (index, datum) in row.iter().enumerate() {
414                    if datum.is_null() {
415                        typ.column_types[index].nullable = true;
416                    }
417                }
418            }
419            *relation_type = typ.clone();
420        }
421
422        Ok(())
423    }
424
425    // TODO(benesch): remove this once this function no longer makes use of
426    // potentially dangerous `as` conversions.
427    #[allow(clippy::as_conversions)]
428    fn fold_reduce_constant(
429        group_key: &[MirScalarExpr],
430        aggregates: &[AggregateExpr],
431        rows: &[(Row, Diff)],
432        limit: Option<usize>,
433    ) -> Option<Result<Vec<(Row, Diff)>, EvalError>> {
434        // Build a map from `group_key` to `Vec<Vec<an, ..., a1>>)`,
435        // where `an` is the input to the nth aggregate function in
436        // `aggregates`.
437        let mut groups = BTreeMap::new();
438        let temp_storage2 = RowArena::new();
439        let mut row_buf = Row::default();
440        let mut limit_remaining =
441            limit.map_or(Diff::MAX, |limit| Diff::try_from(limit).expect("must fit"));
442        for (row, diff) in rows {
443            // We currently maintain the invariant that any negative
444            // multiplicities will be consolidated away before they
445            // arrive at a reduce.
446
447            if *diff <= Diff::ZERO {
448                return Some(Err(EvalError::InvalidParameterValue(
449                    "constant folding encountered reduce on collection with non-positive multiplicities".into()
450                )));
451            }
452
453            if limit_remaining < *diff {
454                return None;
455            }
456            limit_remaining -= diff;
457
458            let datums = row.unpack();
459            let temp_storage = RowArena::new();
460            let key = match group_key
461                .iter()
462                .map(|e| e.eval(&datums, &temp_storage2))
463                .collect::<Result<Vec<_>, _>>()
464            {
465                Ok(key) => key,
466                Err(e) => return Some(Err(e)),
467            };
468            let val = match aggregates
469                .iter()
470                .map(|agg| {
471                    row_buf
472                        .packer()
473                        .extend([agg.expr.eval(&datums, &temp_storage)?]);
474                    Ok::<_, EvalError>(row_buf.clone())
475                })
476                .collect::<Result<Vec<_>, _>>()
477            {
478                Ok(val) => val,
479                Err(e) => return Some(Err(e)),
480            };
481            let entry = groups.entry(key).or_insert_with(Vec::new);
482            for _ in 0..diff.into_inner() {
483                entry.push(val.clone());
484            }
485        }
486
487        // For each group, apply the aggregate function to the rows
488        // in the group. The output is
489        // `Vec<Vec<k1, ..., kn, r1, ..., rn>>`
490        // where kn is the nth column of the key and rn is the
491        // result of the nth aggregate function for that group.
492        let new_rows = groups
493            .into_iter()
494            .map({
495                let mut row_buf = Row::default();
496                move |(key, vals)| {
497                    let temp_storage = RowArena::new();
498                    row_buf.packer().extend(key.into_iter().chain(
499                        aggregates.iter().enumerate().map(|(i, agg)| {
500                            if agg.distinct {
501                                agg.func.eval(
502                                    vals.iter()
503                                        .map(|val| val[i].unpack_first())
504                                        .collect::<BTreeSet<_>>(),
505                                    &temp_storage,
506                                )
507                            } else {
508                                agg.func.eval(
509                                    vals.iter().map(|val| val[i].unpack_first()),
510                                    &temp_storage,
511                                )
512                            }
513                        }),
514                    ));
515                    (row_buf.clone(), Diff::ONE)
516                }
517            })
518            .collect();
519        Some(Ok(new_rows))
520    }
521
522    fn fold_topk_constant<'a>(
523        group_key: &[usize],
524        order_key: &[ColumnOrder],
525        limit: &Option<Diff>,
526        offset: &usize,
527        rows: &'a mut [(Row, Diff)],
528    ) {
529        // helper functions for comparing elements by order_key and group_key
530        let comparator = RowComparator::new(order_key);
531
532        let mut cmp_order_key = |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
533            comparator.compare_rows(&lhs.0, &rhs.0, || lhs.cmp(rhs))
534        };
535        let mut cmp_group_key = {
536            let group_key = group_key
537                .iter()
538                .map(|column| ColumnOrder {
539                    column: *column,
540                    // desc and nulls_last don't matter: the sorting by cmp_group_key is just to
541                    // make the elements of each group appear next to each other, but the order of
542                    // groups doesn't matter.
543                    desc: false,
544                    nulls_last: false,
545                })
546                .collect::<Vec<ColumnOrder>>();
547            let comparator = RowComparator::new(group_key);
548            move |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
549                comparator.compare_rows(&lhs.0, &rhs.0, || Ordering::Equal)
550            }
551        };
552
553        // compute Ordering based on the sort_key, otherwise consider all rows equal
554        rows.sort_by(&mut cmp_order_key);
555
556        // sort by the grouping key if not empty, keeping order_key as a secondary sort
557        if !group_key.is_empty() {
558            rows.sort_by(&mut cmp_group_key);
559        };
560
561        let same_group_key =
562            |lhs: &(Row, Diff), rhs: &(Row, Diff)| cmp_group_key(lhs, rhs) == Ordering::Equal;
563
564        let mut cursor = 0;
565        while cursor < rows.len() {
566            // first, reset the remaining limit and offset for the current group
567            let mut offset_rem: Diff = offset.clone().try_into().unwrap();
568            let mut limit_rem: Option<Diff> = limit.clone();
569
570            let mut finger = cursor;
571            while finger < rows.len() && same_group_key(&rows[cursor], &rows[finger]) {
572                if rows[finger].1.is_negative() {
573                    // ignore elements with negative diff
574                    rows[finger].1 = Diff::ZERO;
575                } else {
576                    // determine how many of the leading rows to ignore,
577                    // then decrement the diff and remaining offset by that number
578                    let rows_to_ignore = std::cmp::min(offset_rem, rows[finger].1);
579                    rows[finger].1 -= rows_to_ignore;
580                    offset_rem -= rows_to_ignore;
581                    // determine how many of the remaining rows to retain,
582                    // then update the diff and decrement the remaining limit by that number
583                    if let Some(limit_rem) = &mut limit_rem {
584                        let rows_to_retain = std::cmp::min(*limit_rem, rows[finger].1);
585                        rows[finger].1 = rows_to_retain;
586                        *limit_rem -= rows_to_retain;
587                    }
588                }
589                finger += 1;
590            }
591            cursor = finger;
592        }
593    }
594
595    fn fold_flat_map_constant(
596        func: &TableFunc,
597        exprs: &[MirScalarExpr],
598        rows: &[(Row, Diff)],
599        limit: Option<usize>,
600    ) -> Result<Option<Vec<(Row, Diff)>>, EvalError> {
601        // We cannot exceed `usize::MAX` in any array, so this is a fine upper bound.
602        let limit = limit.unwrap_or(usize::MAX);
603        let mut new_rows = Vec::new();
604        let mut row_buf = Row::default();
605        let mut datum_vec = mz_repr::DatumVec::new();
606        for (input_row, diff) in rows {
607            let datums = datum_vec.borrow_with(input_row);
608            let temp_storage = RowArena::new();
609            let datums = exprs
610                .iter()
611                .map(|expr| expr.eval(&datums, &temp_storage))
612                .collect::<Result<Vec<_>, _>>()?;
613            let mut output_rows = func.eval(&datums, &temp_storage)?.fuse();
614            for (output_row, diff2) in (&mut output_rows).take(limit - new_rows.len()) {
615                let mut packer = row_buf.packer();
616                packer.extend_by_row(input_row);
617                packer.extend_by_row(&output_row);
618                new_rows.push((row_buf.clone(), diff2 * *diff))
619            }
620            // If we still have records to enumerate, but dropped out of the iteration,
621            // it means we have exhausted `limit` and should stop.
622            if output_rows.next() != None {
623                return Ok(None);
624            }
625        }
626        Ok(Some(new_rows))
627    }
628
629    fn fold_filter_constant(
630        predicates: &[MirScalarExpr],
631        rows: &[(Row, Diff)],
632    ) -> Result<Vec<(Row, Diff)>, EvalError> {
633        let mut new_rows = Vec::new();
634        let mut datum_vec = mz_repr::DatumVec::new();
635        'outer: for (row, diff) in rows {
636            let datums = datum_vec.borrow_with(row);
637            let temp_storage = RowArena::new();
638            for p in &*predicates {
639                if p.eval(&datums, &temp_storage)? != Datum::True {
640                    continue 'outer;
641                }
642            }
643            new_rows.push((row.clone(), *diff))
644        }
645        Ok(new_rows)
646    }
647}