mz_transform/
fold_constants.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Replace operators on constant collections with constant collections.
11
12use std::cmp::Ordering;
13use std::collections::{BTreeMap, BTreeSet};
14use std::convert::TryInto;
15use std::iter;
16
17use mz_expr::visit::Visit;
18use mz_expr::{
19    AggregateExpr, ColumnOrder, EvalError, MirRelationExpr, MirScalarExpr, TableFunc, UnaryFunc,
20};
21use mz_repr::{Datum, Diff, Row, RowArena, SqlRelationType};
22
23use crate::{TransformCtx, TransformError, any};
24
/// Replace operators on constant collections with constant collections.
///
/// The per-node rewrite rules live in [`FoldConstants::action`].
#[derive(Debug)]
pub struct FoldConstants {
    /// An optional maximum size, after which optimization can cease.
    ///
    /// When set, this bounds:
    /// - the total multiplicity of the constant input to a `Reduce`;
    /// - the number of candidate rows in the cross product of a constant
    ///   `Join`, before its equivalences are applied;
    /// - the number of output records of a constant `FlatMap`.
    ///
    /// If a bound would be exceeded, that operator is left unfolded. Other
    /// operators do not consult this limit.
    ///
    /// The `None` value here indicates no maximum size, but does not
    /// currently guarantee that any constant expression will be reduced
    /// to a `MirRelationExpr::Constant` variant.
    pub limit: Option<usize>,
}
35
36impl crate::Transform for FoldConstants {
37    fn name(&self) -> &'static str {
38        "FoldConstants"
39    }
40
41    #[mz_ore::instrument(
42        target = "optimizer",
43        level = "debug",
44        fields(path.segment = "fold_constants")
45    )]
46    fn actually_perform_transform(
47        &self,
48        relation: &mut MirRelationExpr,
49        _: &mut TransformCtx,
50    ) -> Result<(), TransformError> {
51        let mut type_stack = Vec::new();
52        let result = relation.try_visit_mut_post(&mut |e| -> Result<(), TransformError> {
53            let num_inputs = e.num_inputs();
54            let input_types = &type_stack[type_stack.len() - num_inputs..];
55            let mut relation_type = e.typ_with_input_types(input_types);
56            self.action(e, &mut relation_type)?;
57            type_stack.truncate(type_stack.len() - num_inputs);
58            type_stack.push(relation_type);
59            Ok(())
60        });
61        mz_repr::explain::trace_plan(&*relation);
62        result
63    }
64}
65
66impl FoldConstants {
67    /// Replace operators on constants collections with constant collections.
68    ///
69    /// This transform will cease optimization if it encounters constant collections
70    /// that are larger than `self.limit`, if that is set. It is not guaranteed that
71    /// a constant input within the limit will be reduced to a `Constant` variant.
72    pub fn action(
73        &self,
74        relation: &mut MirRelationExpr,
75        relation_type: &mut SqlRelationType,
76    ) -> Result<(), TransformError> {
77        match relation {
78            MirRelationExpr::Constant { .. } => { /* handled after match */ }
79            MirRelationExpr::Get { .. } => {}
80            MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
81                // Constant propagation through bindings is currently handled by in NormalizeLets.
82                // Maybe we should move it / replicate it here (see database-issues#5346 for context)?
83            }
84            MirRelationExpr::Reduce {
85                input,
86                group_key,
87                aggregates,
88                monotonic: _,
89                expected_group_size: _,
90            } => {
91                // Guard against evaluating an expression that may contain
92                // unmaterializable functions.
93                if group_key.iter().any(|e| e.contains_unmaterializable())
94                    || aggregates
95                        .iter()
96                        .any(|a| a.expr.contains_unmaterializable())
97                {
98                    return Ok(());
99                }
100
101                if let Some((rows, ..)) = (**input).as_const() {
102                    let new_rows = match rows {
103                        Ok(rows) => {
104                            if let Some(rows) =
105                                Self::fold_reduce_constant(group_key, aggregates, rows, self.limit)
106                            {
107                                rows
108                            } else {
109                                return Ok(());
110                            }
111                        }
112                        Err(e) => Err(e.clone()),
113                    };
114                    *relation = MirRelationExpr::Constant {
115                        rows: new_rows,
116                        typ: relation_type.clone(),
117                    };
118                }
119            }
120            MirRelationExpr::TopK {
121                input,
122                group_key,
123                order_key,
124                limit,
125                offset,
126                ..
127            } => {
128                // Only fold constants when:
129                //
130                // 1. The `limit` value is not set, or
131                // 2. The `limit` value is set to a literal x such that x >= 0.
132                //
133                // We can improve this to arbitrary expressions, but it requires
134                // more typing.
135                if any![
136                    limit.is_none(),
137                    limit.as_ref().and_then(|l| l.as_literal_int64()) >= Some(0),
138                ] {
139                    let limit = limit
140                        .as_ref()
141                        .and_then(|l| l.as_literal_int64().map(Into::into));
142                    if let Some((rows, ..)) = (**input).as_const_mut() {
143                        if let Ok(rows) = rows {
144                            Self::fold_topk_constant(group_key, order_key, &limit, offset, rows);
145                        }
146                        *relation = input.take_dangerous();
147                    }
148                }
149            }
150            MirRelationExpr::Negate { input } => {
151                if let Some((rows, ..)) = (**input).as_const_mut() {
152                    if let Ok(rows) = rows {
153                        for (_row, diff) in rows {
154                            *diff = -*diff;
155                        }
156                    }
157                    *relation = input.take_dangerous();
158                }
159            }
160            MirRelationExpr::Threshold { input } => {
161                if let Some((rows, ..)) = (**input).as_const_mut() {
162                    if let Ok(rows) = rows {
163                        rows.retain(|(_, diff)| diff.is_positive());
164                    }
165                    *relation = input.take_dangerous();
166                }
167            }
168            MirRelationExpr::Map { input, scalars } => {
169                // Guard against evaluating expression that may contain
170                // unmaterializable functions.
171                if scalars.iter().any(|e| e.contains_unmaterializable()) {
172                    return Ok(());
173                }
174
175                if let Some((rows, ..)) = (**input).as_const() {
176                    // Do not evaluate calls if:
177                    // 1. The input consist of at least one row, and
178                    // 2. The scalars is a singleton mz_panic('forced panic') call.
179                    // Instead, indicate to the caller to panic.
180                    if rows.as_ref().map_or(0, |r| r.len()) > 0 && scalars.len() == 1 {
181                        if let MirScalarExpr::CallUnary {
182                            func: UnaryFunc::Panic(_),
183                            expr,
184                        } = &scalars[0]
185                        {
186                            if let Some("forced panic") = expr.as_literal_str() {
187                                let msg = "forced panic".to_string();
188                                return Err(TransformError::CallerShouldPanic(msg));
189                            }
190                        }
191                    }
192
193                    let new_rows = match rows {
194                        Ok(rows) => rows
195                            .iter()
196                            .map(|(input_row, diff)| {
197                                // TODO: reduce allocations to zero.
198                                let mut unpacked = input_row.unpack();
199                                let temp_storage = RowArena::new();
200                                for scalar in scalars.iter() {
201                                    unpacked.push(scalar.eval(&unpacked, &temp_storage)?)
202                                }
203                                Ok::<_, EvalError>((Row::pack_slice(&unpacked), *diff))
204                            })
205                            .collect::<Result<_, _>>(),
206                        Err(e) => Err(e.clone()),
207                    };
208                    *relation = MirRelationExpr::Constant {
209                        rows: new_rows,
210                        typ: relation_type.clone(),
211                    };
212                }
213            }
214            MirRelationExpr::FlatMap { input, func, exprs } => {
215                // Guard against evaluating expression that may contain unmaterializable functions.
216                if exprs.iter().any(|e| e.contains_unmaterializable()) {
217                    return Ok(());
218                }
219
220                if let Some((rows, ..)) = (**input).as_const() {
221                    let new_rows = match rows {
222                        Ok(rows) => Self::fold_flat_map_constant(func, exprs, rows, self.limit),
223                        Err(e) => Err(e.clone()),
224                    };
225                    match new_rows {
226                        Ok(None) => {}
227                        Ok(Some(rows)) => {
228                            *relation = MirRelationExpr::Constant {
229                                rows: Ok(rows),
230                                typ: relation_type.clone(),
231                            };
232                        }
233                        Err(err) => {
234                            *relation = MirRelationExpr::Constant {
235                                rows: Err(err),
236                                typ: relation_type.clone(),
237                            };
238                        }
239                    };
240                }
241            }
242            MirRelationExpr::Filter { input, predicates } => {
243                // Guard against evaluating expression that may contain
244                // unmaterializable function calls.
245                if predicates.iter().any(|e| e.contains_unmaterializable()) {
246                    return Ok(());
247                }
248
249                // If any predicate is false, reduce to the empty collection.
250                if predicates
251                    .iter()
252                    .any(|p| p.is_literal_false() || p.is_literal_null())
253                {
254                    relation.take_safely(Some(relation_type.clone()));
255                } else if let Some((rows, ..)) = (**input).as_const() {
256                    // Evaluate errors last, to reduce risk of spurious errors.
257                    predicates.sort_by_key(|p| p.is_literal_err());
258                    let new_rows = match rows {
259                        Ok(rows) => Self::fold_filter_constant(predicates, rows),
260                        Err(e) => Err(e.clone()),
261                    };
262                    *relation = MirRelationExpr::Constant {
263                        rows: new_rows,
264                        typ: relation_type.clone(),
265                    };
266                }
267            }
268            MirRelationExpr::Project { input, outputs } => {
269                if let Some((rows, ..)) = (**input).as_const() {
270                    let mut row_buf = Row::default();
271                    let new_rows = match rows {
272                        Ok(rows) => Ok(rows
273                            .iter()
274                            .map(|(input_row, diff)| {
275                                // TODO: reduce allocations to zero.
276                                let datums = input_row.unpack();
277                                row_buf.packer().extend(outputs.iter().map(|i| &datums[*i]));
278                                (row_buf.clone(), *diff)
279                            })
280                            .collect()),
281                        Err(e) => Err(e.clone()),
282                    };
283                    *relation = MirRelationExpr::Constant {
284                        rows: new_rows,
285                        typ: relation_type.clone(),
286                    };
287                }
288            }
289            MirRelationExpr::Join {
290                inputs,
291                equivalences,
292                ..
293            } => {
294                if inputs.iter().any(|e| e.is_empty()) {
295                    relation.take_safely(Some(relation_type.clone()));
296                } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) {
297                    *relation = MirRelationExpr::Constant {
298                        rows: Err(e.clone()),
299                        typ: relation_type.clone(),
300                    };
301                } else if inputs
302                    .iter()
303                    .all(|i| matches!(i.as_const(), Some((Ok(_), ..))))
304                {
305                    // Guard against evaluating expression that may contain unmaterializable functions.
306                    if equivalences
307                        .iter()
308                        .any(|equiv| equiv.iter().any(|e| e.contains_unmaterializable()))
309                    {
310                        return Ok(());
311                    }
312
313                    // We can fold all constant inputs together, but must apply the constraints to restrict them.
314                    // We start with a single 0-ary row.
315                    let mut old_rows = vec![(Row::pack::<_, Datum>(None), Diff::ONE)];
316                    let mut row_buf = Row::default();
317                    for input in inputs.iter() {
318                        if let Some((Ok(rows), ..)) = input.as_const() {
319                            if let Some(limit) = self.limit {
320                                if old_rows.len() * rows.len() > limit {
321                                    // Bail out if we have produced too many rows.
322                                    // TODO: progressively apply equivalences to narrow this count
323                                    // as we go, rather than at the end.
324                                    return Ok(());
325                                }
326                            }
327                            let mut next_rows = Vec::new();
328                            for (old_row, old_count) in old_rows {
329                                for (new_row, new_count) in rows.iter() {
330                                    let mut packer = row_buf.packer();
331                                    packer.extend_by_row(&old_row);
332                                    packer.extend_by_row(new_row);
333                                    next_rows.push((row_buf.clone(), old_count * *new_count));
334                                }
335                            }
336                            old_rows = next_rows;
337                        }
338                    }
339
340                    // Now throw away anything that doesn't satisfy the requisite constraints.
341                    let mut datum_vec = mz_repr::DatumVec::new();
342                    old_rows.retain(|(row, _count)| {
343                        let datums = datum_vec.borrow_with(row);
344                        let temp_storage = RowArena::new();
345                        equivalences.iter().all(|equivalence| {
346                            let mut values =
347                                equivalence.iter().map(|e| e.eval(&datums, &temp_storage));
348                            if let Some(value) = values.next() {
349                                values.all(|v| v == value)
350                            } else {
351                                true
352                            }
353                        })
354                    });
355
356                    *relation = MirRelationExpr::Constant {
357                        rows: Ok(old_rows),
358                        typ: relation_type.clone(),
359                    };
360                }
361                // TODO: General constant folding for all constant inputs.
362            }
363            MirRelationExpr::Union { base, inputs } => {
364                if let Some(e) = iter::once(&mut **base)
365                    .chain(&mut *inputs)
366                    .find_map(|i| i.as_const_err())
367                {
368                    *relation = MirRelationExpr::Constant {
369                        rows: Err(e.clone()),
370                        typ: relation_type.clone(),
371                    };
372                } else {
373                    let mut rows = vec![];
374                    let mut new_inputs = vec![];
375
376                    for input in iter::once(&mut **base).chain(&mut *inputs) {
377                        if let Some((Ok(rs), ..)) = input.as_const() {
378                            rows.extend(rs.clone());
379                        } else {
380                            new_inputs.push(input.clone())
381                        }
382                    }
383                    if !rows.is_empty() {
384                        new_inputs.push(MirRelationExpr::Constant {
385                            rows: Ok(rows),
386                            typ: relation_type.clone(),
387                        });
388                    }
389
390                    *relation = MirRelationExpr::union_many(new_inputs, relation_type.clone());
391                }
392            }
393            MirRelationExpr::ArrangeBy { .. } => {
394                // Don't fold ArrangeBys, because that could result in unarranged Delta join inputs.
395                // See also the comment on `MirRelationExpr::Constant`.
396            }
397        }
398
399        // This transformation maintains the invariant that all constant nodes
400        // will be consolidated. We have to make a separate check for constant
401        // nodes here, since the match arm above might install new constant
402        // nodes.
403        if let Some((Ok(rows), typ)) = relation.as_const_mut() {
404            // Reduce down to canonical representation.
405            differential_dataflow::consolidation::consolidate(rows);
406
407            // Re-establish nullability of each column.
408            for col_type in typ.column_types.iter_mut() {
409                col_type.nullable = false;
410            }
411            for (row, _) in rows.iter_mut() {
412                for (index, datum) in row.iter().enumerate() {
413                    if datum.is_null() {
414                        typ.column_types[index].nullable = true;
415                    }
416                }
417            }
418            *relation_type = typ.clone();
419        }
420
421        Ok(())
422    }
423
424    // TODO(benesch): remove this once this function no longer makes use of
425    // potentially dangerous `as` conversions.
426    #[allow(clippy::as_conversions)]
427    fn fold_reduce_constant(
428        group_key: &[MirScalarExpr],
429        aggregates: &[AggregateExpr],
430        rows: &[(Row, Diff)],
431        limit: Option<usize>,
432    ) -> Option<Result<Vec<(Row, Diff)>, EvalError>> {
433        // Build a map from `group_key` to `Vec<Vec<an, ..., a1>>)`,
434        // where `an` is the input to the nth aggregate function in
435        // `aggregates`.
436        let mut groups = BTreeMap::new();
437        let temp_storage2 = RowArena::new();
438        let mut row_buf = Row::default();
439        let mut limit_remaining =
440            limit.map_or(Diff::MAX, |limit| Diff::try_from(limit).expect("must fit"));
441        for (row, diff) in rows {
442            // We currently maintain the invariant that any negative
443            // multiplicities will be consolidated away before they
444            // arrive at a reduce.
445
446            if *diff <= Diff::ZERO {
447                return Some(Err(EvalError::InvalidParameterValue(
448                    "constant folding encountered reduce on collection with non-positive multiplicities".into()
449                )));
450            }
451
452            if limit_remaining < *diff {
453                return None;
454            }
455            limit_remaining -= diff;
456
457            let datums = row.unpack();
458            let temp_storage = RowArena::new();
459            let key = match group_key
460                .iter()
461                .map(|e| e.eval(&datums, &temp_storage2))
462                .collect::<Result<Vec<_>, _>>()
463            {
464                Ok(key) => key,
465                Err(e) => return Some(Err(e)),
466            };
467            let val = match aggregates
468                .iter()
469                .map(|agg| {
470                    row_buf
471                        .packer()
472                        .extend([agg.expr.eval(&datums, &temp_storage)?]);
473                    Ok::<_, EvalError>(row_buf.clone())
474                })
475                .collect::<Result<Vec<_>, _>>()
476            {
477                Ok(val) => val,
478                Err(e) => return Some(Err(e)),
479            };
480            let entry = groups.entry(key).or_insert_with(Vec::new);
481            for _ in 0..diff.into_inner() {
482                entry.push(val.clone());
483            }
484        }
485
486        // For each group, apply the aggregate function to the rows
487        // in the group. The output is
488        // `Vec<Vec<k1, ..., kn, r1, ..., rn>>`
489        // where kn is the nth column of the key and rn is the
490        // result of the nth aggregate function for that group.
491        let new_rows = groups
492            .into_iter()
493            .map({
494                let mut row_buf = Row::default();
495                move |(key, vals)| {
496                    let temp_storage = RowArena::new();
497                    row_buf.packer().extend(key.into_iter().chain(
498                        aggregates.iter().enumerate().map(|(i, agg)| {
499                            if agg.distinct {
500                                agg.func.eval(
501                                    vals.iter()
502                                        .map(|val| val[i].unpack_first())
503                                        .collect::<BTreeSet<_>>(),
504                                    &temp_storage,
505                                )
506                            } else {
507                                agg.func.eval(
508                                    vals.iter().map(|val| val[i].unpack_first()),
509                                    &temp_storage,
510                                )
511                            }
512                        }),
513                    ));
514                    (row_buf.clone(), Diff::ONE)
515                }
516            })
517            .collect();
518        Some(Ok(new_rows))
519    }
520
521    fn fold_topk_constant<'a>(
522        group_key: &[usize],
523        order_key: &[ColumnOrder],
524        limit: &Option<Diff>,
525        offset: &usize,
526        rows: &'a mut [(Row, Diff)],
527    ) {
528        // helper functions for comparing elements by order_key and group_key
529        let mut lhs_datum_vec = mz_repr::DatumVec::new();
530        let mut rhs_datum_vec = mz_repr::DatumVec::new();
531        let mut cmp_order_key = |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
532            let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
533            let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
534            mz_expr::compare_columns(order_key, lhs_datums, rhs_datums, || lhs.cmp(rhs))
535        };
536        let mut cmp_group_key = {
537            let group_key = group_key
538                .iter()
539                .map(|column| ColumnOrder {
540                    column: *column,
541                    // desc and nulls_last don't matter: the sorting by cmp_group_key is just to
542                    // make the elements of each group appear next to each other, but the order of
543                    // groups doesn't matter.
544                    desc: false,
545                    nulls_last: false,
546                })
547                .collect::<Vec<ColumnOrder>>();
548            let mut lhs_datum_vec = mz_repr::DatumVec::new();
549            let mut rhs_datum_vec = mz_repr::DatumVec::new();
550            move |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
551                let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
552                let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
553                mz_expr::compare_columns(&group_key, lhs_datums, rhs_datums, || Ordering::Equal)
554            }
555        };
556
557        // compute Ordering based on the sort_key, otherwise consider all rows equal
558        rows.sort_by(&mut cmp_order_key);
559
560        // sort by the grouping key if not empty, keeping order_key as a secondary sort
561        if !group_key.is_empty() {
562            rows.sort_by(&mut cmp_group_key);
563        };
564
565        let mut same_group_key =
566            |lhs: &(Row, Diff), rhs: &(Row, Diff)| cmp_group_key(lhs, rhs) == Ordering::Equal;
567
568        let mut cursor = 0;
569        while cursor < rows.len() {
570            // first, reset the remaining limit and offset for the current group
571            let mut offset_rem: Diff = offset.clone().try_into().unwrap();
572            let mut limit_rem: Option<Diff> = limit.clone();
573
574            let mut finger = cursor;
575            while finger < rows.len() && same_group_key(&rows[cursor], &rows[finger]) {
576                if rows[finger].1.is_negative() {
577                    // ignore elements with negative diff
578                    rows[finger].1 = Diff::ZERO;
579                } else {
580                    // determine how many of the leading rows to ignore,
581                    // then decrement the diff and remaining offset by that number
582                    let rows_to_ignore = std::cmp::min(offset_rem, rows[finger].1);
583                    rows[finger].1 -= rows_to_ignore;
584                    offset_rem -= rows_to_ignore;
585                    // determine how many of the remaining rows to retain,
586                    // then update the diff and decrement the remaining limit by that number
587                    if let Some(limit_rem) = &mut limit_rem {
588                        let rows_to_retain = std::cmp::min(*limit_rem, rows[finger].1);
589                        rows[finger].1 = rows_to_retain;
590                        *limit_rem -= rows_to_retain;
591                    }
592                }
593                finger += 1;
594            }
595            cursor = finger;
596        }
597    }
598
599    fn fold_flat_map_constant(
600        func: &TableFunc,
601        exprs: &[MirScalarExpr],
602        rows: &[(Row, Diff)],
603        limit: Option<usize>,
604    ) -> Result<Option<Vec<(Row, Diff)>>, EvalError> {
605        // We cannot exceed `usize::MAX` in any array, so this is a fine upper bound.
606        let limit = limit.unwrap_or(usize::MAX);
607        let mut new_rows = Vec::new();
608        let mut row_buf = Row::default();
609        let mut datum_vec = mz_repr::DatumVec::new();
610        for (input_row, diff) in rows {
611            let datums = datum_vec.borrow_with(input_row);
612            let temp_storage = RowArena::new();
613            let datums = exprs
614                .iter()
615                .map(|expr| expr.eval(&datums, &temp_storage))
616                .collect::<Result<Vec<_>, _>>()?;
617            let mut output_rows = func.eval(&datums, &temp_storage)?.fuse();
618            for (output_row, diff2) in (&mut output_rows).take(limit - new_rows.len()) {
619                let mut packer = row_buf.packer();
620                packer.extend_by_row(input_row);
621                packer.extend_by_row(&output_row);
622                new_rows.push((row_buf.clone(), diff2 * *diff))
623            }
624            // If we still have records to enumerate, but dropped out of the iteration,
625            // it means we have exhausted `limit` and should stop.
626            if output_rows.next() != None {
627                return Ok(None);
628            }
629        }
630        Ok(Some(new_rows))
631    }
632
633    fn fold_filter_constant(
634        predicates: &[MirScalarExpr],
635        rows: &[(Row, Diff)],
636    ) -> Result<Vec<(Row, Diff)>, EvalError> {
637        let mut new_rows = Vec::new();
638        let mut datum_vec = mz_repr::DatumVec::new();
639        'outer: for (row, diff) in rows {
640            let datums = datum_vec.borrow_with(row);
641            let temp_storage = RowArena::new();
642            for p in &*predicates {
643                if p.eval(&datums, &temp_storage)? != Datum::True {
644                    continue 'outer;
645                }
646            }
647            new_rows.push((row.clone(), *diff))
648        }
649        Ok(new_rows)
650    }
651}