mz_transform/fold_constants.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Replace operators on constant collections with constant collections.
11
12use std::cmp::Ordering;
13use std::collections::{BTreeMap, BTreeSet};
14use std::convert::TryInto;
15use std::iter;
16
17use mz_expr::visit::Visit;
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Eval, EvalError, MirRelationExpr, MirScalarExpr, RowComparator,
20 TableFunc, UnaryFunc,
21};
22use mz_repr::{Datum, Diff, ReprRelationType, Row, RowArena};
23
24use crate::{TransformCtx, TransformError, any};
25
/// Replace operators on constant collections with constant collections.
///
/// The transform walks a `MirRelationExpr` in post-order and, for each
/// operator whose inputs are `Constant` collections, tries to evaluate the
/// operator and install the resulting `Constant` in its place.
#[derive(Debug)]
pub struct FoldConstants {
    /// An optional maximum size, after which optimization can cease.
    ///
    /// The `None` value here indicates no maximum size, but does not
    /// currently guarantee that any constant expression will be reduced
    /// to a `MirRelationExpr::Constant` variant.
    ///
    /// Within this file the limit is consulted when folding `Reduce`,
    /// `FlatMap`, and `Join`; the remaining operators are folded without a
    /// size check.
    pub limit: Option<usize>,
}
36
37impl crate::Transform for FoldConstants {
38 fn name(&self) -> &'static str {
39 "FoldConstants"
40 }
41
42 #[mz_ore::instrument(
43 target = "optimizer",
44 level = "debug",
45 fields(path.segment = "fold_constants")
46 )]
47 fn actually_perform_transform(
48 &self,
49 relation: &mut MirRelationExpr,
50 _: &mut TransformCtx,
51 ) -> Result<(), TransformError> {
52 let mut type_stack = Vec::new();
53 let result = relation.try_visit_mut_post(&mut |e| -> Result<(), TransformError> {
54 let num_inputs = e.num_inputs();
55 let input_types = &type_stack[type_stack.len() - num_inputs..];
56 let mut relation_type = e.typ_with_input_types(input_types);
57 self.action(e, &mut relation_type)?;
58 type_stack.truncate(type_stack.len() - num_inputs);
59 type_stack.push(relation_type);
60 Ok(())
61 });
62 mz_repr::explain::trace_plan(&*relation);
63 result
64 }
65}
66
67impl FoldConstants {
68 /// Replace operators on constants collections with constant collections.
69 ///
70 /// This transform will cease optimization if it encounters constant collections
71 /// that are larger than `self.limit`, if that is set. It is not guaranteed that
72 /// a constant input within the limit will be reduced to a `Constant` variant.
73 pub fn action(
74 &self,
75 relation: &mut MirRelationExpr,
76 relation_type: &mut ReprRelationType,
77 ) -> Result<(), TransformError> {
78 match relation {
79 MirRelationExpr::Constant { .. } => { /* handled after match */ }
80 MirRelationExpr::Get { .. } => {}
81 MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
82 // Constant propagation through bindings is currently handled by in NormalizeLets.
83 // Maybe we should move it / replicate it here (see database-issues#5346 for context)?
84 }
85 MirRelationExpr::Reduce {
86 input,
87 group_key,
88 aggregates,
89 monotonic: _,
90 expected_group_size: _,
91 } => {
92 // Guard against evaluating an expression that may contain
93 // unmaterializable functions.
94 if group_key.iter().any(|e| e.contains_unmaterializable())
95 || aggregates
96 .iter()
97 .any(|a| a.expr.contains_unmaterializable())
98 {
99 return Ok(());
100 }
101
102 if let Some((rows, ..)) = (**input).as_const() {
103 let new_rows = match rows {
104 Ok(rows) => {
105 if let Some(rows) =
106 Self::fold_reduce_constant(group_key, aggregates, rows, self.limit)
107 {
108 rows
109 } else {
110 return Ok(());
111 }
112 }
113 Err(e) => Err(e.clone()),
114 };
115 *relation = MirRelationExpr::Constant {
116 rows: new_rows,
117 typ: relation_type.clone(),
118 };
119 }
120 }
121 MirRelationExpr::TopK {
122 input,
123 group_key,
124 order_key,
125 limit,
126 offset,
127 ..
128 } => {
129 // Only fold constants when:
130 //
131 // 1. The `limit` value is not set, or
132 // 2. The `limit` value is set to a literal x such that x >= 0.
133 //
134 // We can improve this to arbitrary expressions, but it requires
135 // more typing.
136 if any![
137 limit.is_none(),
138 limit.as_ref().and_then(|l| l.as_literal_int64()) >= Some(0),
139 ] {
140 let limit = limit
141 .as_ref()
142 .and_then(|l| l.as_literal_int64().map(Into::into));
143 if let Some((rows, ..)) = (**input).as_const_mut() {
144 if let Ok(rows) = rows {
145 Self::fold_topk_constant(group_key, order_key, &limit, offset, rows);
146 }
147 *relation = input.take_dangerous();
148 }
149 }
150 }
151 MirRelationExpr::Negate { input } => {
152 if let Some((rows, ..)) = (**input).as_const_mut() {
153 if let Ok(rows) = rows {
154 for (_row, diff) in rows {
155 *diff = -*diff;
156 }
157 }
158 *relation = input.take_dangerous();
159 }
160 }
161 MirRelationExpr::Threshold { input } => {
162 if let Some((rows, ..)) = (**input).as_const_mut() {
163 if let Ok(rows) = rows {
164 rows.retain(|(_, diff)| diff.is_positive());
165 }
166 *relation = input.take_dangerous();
167 }
168 }
169 MirRelationExpr::Map { input, scalars } => {
170 // Guard against evaluating expression that may contain
171 // unmaterializable functions.
172 if scalars.iter().any(|e| e.contains_unmaterializable()) {
173 return Ok(());
174 }
175
176 if let Some((rows, ..)) = (**input).as_const() {
177 // Do not evaluate calls if:
178 // 1. The input consist of at least one row, and
179 // 2. The scalars is a singleton mz_panic('forced panic') call.
180 // Instead, indicate to the caller to panic.
181 if rows.as_ref().map_or(0, |r| r.len()) > 0 && scalars.len() == 1 {
182 if let MirScalarExpr::CallUnary {
183 func: UnaryFunc::Panic(_),
184 expr,
185 } = &scalars[0]
186 {
187 if let Some("forced panic") = expr.as_literal_str() {
188 let msg = "forced panic".to_string();
189 return Err(TransformError::CallerShouldPanic(msg));
190 }
191 }
192 }
193
194 let new_rows = match rows {
195 Ok(rows) => rows
196 .iter()
197 .map(|(input_row, diff)| {
198 // TODO: reduce allocations to zero.
199 let mut unpacked = input_row.unpack();
200 let temp_storage = RowArena::new();
201 for scalar in scalars.iter() {
202 unpacked.push(scalar.eval(&unpacked, &temp_storage)?)
203 }
204 Ok::<_, EvalError>((Row::pack_slice(&unpacked), *diff))
205 })
206 .collect::<Result<_, _>>(),
207 Err(e) => Err(e.clone()),
208 };
209 *relation = MirRelationExpr::Constant {
210 rows: new_rows,
211 typ: relation_type.clone(),
212 };
213 }
214 }
215 MirRelationExpr::FlatMap { input, func, exprs } => {
216 // Guard against evaluating expression that may contain unmaterializable functions.
217 if exprs.iter().any(|e| e.contains_unmaterializable()) {
218 return Ok(());
219 }
220
221 if let Some((rows, ..)) = (**input).as_const() {
222 let new_rows = match rows {
223 Ok(rows) => Self::fold_flat_map_constant(func, exprs, rows, self.limit),
224 Err(e) => Err(e.clone()),
225 };
226 match new_rows {
227 Ok(None) => {}
228 Ok(Some(rows)) => {
229 *relation = MirRelationExpr::Constant {
230 rows: Ok(rows),
231 typ: relation_type.clone(),
232 };
233 }
234 Err(err) => {
235 *relation = MirRelationExpr::Constant {
236 rows: Err(err),
237 typ: relation_type.clone(),
238 };
239 }
240 };
241 }
242 }
243 MirRelationExpr::Filter { input, predicates } => {
244 // Guard against evaluating expression that may contain
245 // unmaterializable function calls.
246 if predicates.iter().any(|e| e.contains_unmaterializable()) {
247 return Ok(());
248 }
249
250 // If any predicate is false, reduce to the empty collection.
251 if predicates
252 .iter()
253 .any(|p| p.is_literal_false() || p.is_literal_null())
254 {
255 relation.take_safely(Some(relation_type.clone()));
256 } else if let Some((rows, ..)) = (**input).as_const() {
257 // Evaluate errors last, to reduce risk of spurious errors.
258 predicates.sort_by_key(|p| p.is_literal_err());
259 let new_rows = match rows {
260 Ok(rows) => Self::fold_filter_constant(predicates, rows),
261 Err(e) => Err(e.clone()),
262 };
263 *relation = MirRelationExpr::Constant {
264 rows: new_rows,
265 typ: relation_type.clone(),
266 };
267 }
268 }
269 MirRelationExpr::Project { input, outputs } => {
270 if let Some((rows, ..)) = (**input).as_const() {
271 let mut row_buf = Row::default();
272 let new_rows = match rows {
273 Ok(rows) => Ok(rows
274 .iter()
275 .map(|(input_row, diff)| {
276 // TODO: reduce allocations to zero.
277 let datums = input_row.unpack();
278 row_buf.packer().extend(outputs.iter().map(|i| &datums[*i]));
279 (row_buf.clone(), *diff)
280 })
281 .collect()),
282 Err(e) => Err(e.clone()),
283 };
284 *relation = MirRelationExpr::Constant {
285 rows: new_rows,
286 typ: relation_type.clone(),
287 };
288 }
289 }
290 MirRelationExpr::Join {
291 inputs,
292 equivalences,
293 ..
294 } => {
295 if inputs.iter().any(|e| e.is_empty()) {
296 relation.take_safely(Some(relation_type.clone()));
297 } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) {
298 *relation = MirRelationExpr::Constant {
299 rows: Err(e.clone()),
300 typ: relation_type.clone(),
301 };
302 } else if inputs
303 .iter()
304 .all(|i| matches!(i.as_const(), Some((Ok(_), ..))))
305 {
306 // Guard against evaluating expression that may contain unmaterializable functions.
307 if equivalences
308 .iter()
309 .any(|equiv| equiv.iter().any(|e| e.contains_unmaterializable()))
310 {
311 return Ok(());
312 }
313
314 // We can fold all constant inputs together, but must apply the constraints to restrict them.
315 // We start with a single 0-ary row.
316 let mut old_rows = vec![(Row::pack::<_, Datum>(None), Diff::ONE)];
317 let mut row_buf = Row::default();
318 for input in inputs.iter() {
319 if let Some((Ok(rows), ..)) = input.as_const() {
320 if let Some(limit) = self.limit {
321 if old_rows.len() * rows.len() > limit {
322 // Bail out if we have produced too many rows.
323 // TODO: progressively apply equivalences to narrow this count
324 // as we go, rather than at the end.
325 return Ok(());
326 }
327 }
328 let mut next_rows = Vec::new();
329 for (old_row, old_count) in old_rows {
330 for (new_row, new_count) in rows.iter() {
331 let mut packer = row_buf.packer();
332 packer.extend_by_row(&old_row);
333 packer.extend_by_row(new_row);
334 next_rows.push((row_buf.clone(), old_count * *new_count));
335 }
336 }
337 old_rows = next_rows;
338 }
339 }
340
341 // Now throw away anything that doesn't satisfy the requisite constraints.
342 let mut datum_vec = mz_repr::DatumVec::new();
343 old_rows.retain(|(row, _count)| {
344 let datums = datum_vec.borrow_with(row);
345 let temp_storage = RowArena::new();
346 equivalences.iter().all(|equivalence| {
347 let mut values =
348 equivalence.iter().map(|e| e.eval(&datums, &temp_storage));
349 if let Some(value) = values.next() {
350 values.all(|v| v == value)
351 } else {
352 true
353 }
354 })
355 });
356
357 *relation = MirRelationExpr::Constant {
358 rows: Ok(old_rows),
359 typ: relation_type.clone(),
360 };
361 }
362 // TODO: General constant folding for all constant inputs.
363 }
364 MirRelationExpr::Union { base, inputs } => {
365 if let Some(e) = iter::once(&mut **base)
366 .chain(&mut *inputs)
367 .find_map(|i| i.as_const_err())
368 {
369 *relation = MirRelationExpr::Constant {
370 rows: Err(e.clone()),
371 typ: relation_type.clone(),
372 };
373 } else {
374 let mut rows = vec![];
375 let mut new_inputs = vec![];
376
377 for input in iter::once(&mut **base).chain(&mut *inputs) {
378 if let Some((Ok(rs), ..)) = input.as_const() {
379 rows.extend(rs.clone());
380 } else {
381 new_inputs.push(input.clone())
382 }
383 }
384 if !rows.is_empty() {
385 new_inputs.push(MirRelationExpr::Constant {
386 rows: Ok(rows),
387 typ: relation_type.clone(),
388 });
389 }
390
391 *relation = MirRelationExpr::union_many(new_inputs, relation_type.clone());
392 }
393 }
394 MirRelationExpr::ArrangeBy { .. } => {
395 // Don't fold ArrangeBys, because that could result in unarranged Delta join inputs.
396 // See also the comment on `MirRelationExpr::Constant`.
397 }
398 }
399
400 // This transformation maintains the invariant that all constant nodes
401 // will be consolidated. We have to make a separate check for constant
402 // nodes here, since the match arm above might install new constant
403 // nodes.
404 if let Some((Ok(rows), typ)) = relation.as_const_mut() {
405 // Reduce down to canonical representation.
406 differential_dataflow::consolidation::consolidate(rows);
407
408 // Re-establish nullability of each column.
409 for col_type in typ.column_types.iter_mut() {
410 col_type.nullable = false;
411 }
412 for (row, _) in rows.iter_mut() {
413 for (index, datum) in row.iter().enumerate() {
414 if datum.is_null() {
415 typ.column_types[index].nullable = true;
416 }
417 }
418 }
419 *relation_type = typ.clone();
420 }
421
422 Ok(())
423 }
424
425 // TODO(benesch): remove this once this function no longer makes use of
426 // potentially dangerous `as` conversions.
427 #[allow(clippy::as_conversions)]
428 fn fold_reduce_constant(
429 group_key: &[MirScalarExpr],
430 aggregates: &[AggregateExpr],
431 rows: &[(Row, Diff)],
432 limit: Option<usize>,
433 ) -> Option<Result<Vec<(Row, Diff)>, EvalError>> {
434 // Build a map from `group_key` to `Vec<Vec<an, ..., a1>>)`,
435 // where `an` is the input to the nth aggregate function in
436 // `aggregates`.
437 let mut groups = BTreeMap::new();
438 let temp_storage2 = RowArena::new();
439 let mut row_buf = Row::default();
440 let mut limit_remaining =
441 limit.map_or(Diff::MAX, |limit| Diff::try_from(limit).expect("must fit"));
442 for (row, diff) in rows {
443 // We currently maintain the invariant that any negative
444 // multiplicities will be consolidated away before they
445 // arrive at a reduce.
446
447 if *diff <= Diff::ZERO {
448 return Some(Err(EvalError::InvalidParameterValue(
449 "constant folding encountered reduce on collection with non-positive multiplicities".into()
450 )));
451 }
452
453 if limit_remaining < *diff {
454 return None;
455 }
456 limit_remaining -= diff;
457
458 let datums = row.unpack();
459 let temp_storage = RowArena::new();
460 let key = match group_key
461 .iter()
462 .map(|e| e.eval(&datums, &temp_storage2))
463 .collect::<Result<Vec<_>, _>>()
464 {
465 Ok(key) => key,
466 Err(e) => return Some(Err(e)),
467 };
468 let val = match aggregates
469 .iter()
470 .map(|agg| {
471 row_buf
472 .packer()
473 .extend([agg.expr.eval(&datums, &temp_storage)?]);
474 Ok::<_, EvalError>(row_buf.clone())
475 })
476 .collect::<Result<Vec<_>, _>>()
477 {
478 Ok(val) => val,
479 Err(e) => return Some(Err(e)),
480 };
481 // Store the value once alongside its multiplicity rather than
482 // expanding it into `diff` copies. The count-aware aggregate
483 // `eval` consumes the multiplicity directly, so this avoids
484 // materializing a number of rows proportional to the diff.
485 let entry = groups.entry(key).or_insert_with(Vec::new);
486 entry.push((val, *diff));
487 }
488
489 // For each group, apply the aggregate function to the rows
490 // in the group. The output is
491 // `Vec<Vec<k1, ..., kn, r1, ..., rn>>`
492 // where kn is the nth column of the key and rn is the
493 // result of the nth aggregate function for that group.
494 let new_rows = groups
495 .into_iter()
496 .map({
497 let mut row_buf = Row::default();
498 move |(key, vals)| {
499 let temp_storage = RowArena::new();
500 row_buf.packer().extend(key.into_iter().chain(
501 aggregates.iter().enumerate().map(|(i, agg)| {
502 if agg.distinct {
503 // Distinct collapses to one copy per value, so
504 // each distinct datum gets a unit multiplicity.
505 agg.func.eval(
506 vals.iter()
507 .map(|(val, _diff)| val[i].unpack_first())
508 .collect::<BTreeSet<_>>()
509 .into_iter()
510 .map(|datum| (datum, Diff::ONE)),
511 &temp_storage,
512 )
513 } else {
514 agg.func.eval(
515 vals.iter()
516 .map(|(val, diff)| (val[i].unpack_first(), *diff)),
517 &temp_storage,
518 )
519 }
520 }),
521 ));
522 (row_buf.clone(), Diff::ONE)
523 }
524 })
525 .collect();
526 Some(Ok(new_rows))
527 }
528
529 fn fold_topk_constant<'a>(
530 group_key: &[usize],
531 order_key: &[ColumnOrder],
532 limit: &Option<Diff>,
533 offset: &usize,
534 rows: &'a mut [(Row, Diff)],
535 ) {
536 // helper functions for comparing elements by order_key and group_key
537 let comparator = RowComparator::new(order_key);
538
539 let mut cmp_order_key = |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
540 comparator.compare_rows(&lhs.0, &rhs.0, || lhs.cmp(rhs))
541 };
542 let mut cmp_group_key = {
543 let group_key = group_key
544 .iter()
545 .map(|column| ColumnOrder {
546 column: *column,
547 // desc and nulls_last don't matter: the sorting by cmp_group_key is just to
548 // make the elements of each group appear next to each other, but the order of
549 // groups doesn't matter.
550 desc: false,
551 nulls_last: false,
552 })
553 .collect::<Vec<ColumnOrder>>();
554 let comparator = RowComparator::new(group_key);
555 move |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
556 comparator.compare_rows(&lhs.0, &rhs.0, || Ordering::Equal)
557 }
558 };
559
560 // compute Ordering based on the sort_key, otherwise consider all rows equal
561 rows.sort_by(&mut cmp_order_key);
562
563 // sort by the grouping key if not empty, keeping order_key as a secondary sort
564 if !group_key.is_empty() {
565 rows.sort_by(&mut cmp_group_key);
566 };
567
568 let same_group_key =
569 |lhs: &(Row, Diff), rhs: &(Row, Diff)| cmp_group_key(lhs, rhs) == Ordering::Equal;
570
571 let mut cursor = 0;
572 while cursor < rows.len() {
573 // first, reset the remaining limit and offset for the current group
574 let mut offset_rem: Diff = offset.clone().try_into().unwrap();
575 let mut limit_rem: Option<Diff> = limit.clone();
576
577 let mut finger = cursor;
578 while finger < rows.len() && same_group_key(&rows[cursor], &rows[finger]) {
579 if rows[finger].1.is_negative() {
580 // ignore elements with negative diff
581 rows[finger].1 = Diff::ZERO;
582 } else {
583 // determine how many of the leading rows to ignore,
584 // then decrement the diff and remaining offset by that number
585 let rows_to_ignore = std::cmp::min(offset_rem, rows[finger].1);
586 rows[finger].1 -= rows_to_ignore;
587 offset_rem -= rows_to_ignore;
588 // determine how many of the remaining rows to retain,
589 // then update the diff and decrement the remaining limit by that number
590 if let Some(limit_rem) = &mut limit_rem {
591 let rows_to_retain = std::cmp::min(*limit_rem, rows[finger].1);
592 rows[finger].1 = rows_to_retain;
593 *limit_rem -= rows_to_retain;
594 }
595 }
596 finger += 1;
597 }
598 cursor = finger;
599 }
600 }
601
602 fn fold_flat_map_constant(
603 func: &TableFunc,
604 exprs: &[MirScalarExpr],
605 rows: &[(Row, Diff)],
606 limit: Option<usize>,
607 ) -> Result<Option<Vec<(Row, Diff)>>, EvalError> {
608 // We cannot exceed `usize::MAX` in any array, so this is a fine upper bound.
609 let limit = limit.unwrap_or(usize::MAX);
610 let mut new_rows = Vec::new();
611 let mut row_buf = Row::default();
612 let mut datum_vec = mz_repr::DatumVec::new();
613 for (input_row, diff) in rows {
614 let datums = datum_vec.borrow_with(input_row);
615 let temp_storage = RowArena::new();
616 let datums = exprs
617 .iter()
618 .map(|expr| expr.eval(&datums, &temp_storage))
619 .collect::<Result<Vec<_>, _>>()?;
620 let mut output_rows = func.eval(&datums, &temp_storage)?.fuse();
621 for (output_row, diff2) in (&mut output_rows).take(limit - new_rows.len()) {
622 let mut packer = row_buf.packer();
623 packer.extend_by_row(input_row);
624 packer.extend_by_row(&output_row);
625 new_rows.push((row_buf.clone(), diff2 * *diff))
626 }
627 // If we still have records to enumerate, but dropped out of the iteration,
628 // it means we have exhausted `limit` and should stop.
629 if output_rows.next() != None {
630 return Ok(None);
631 }
632 }
633 Ok(Some(new_rows))
634 }
635
636 fn fold_filter_constant(
637 predicates: &[MirScalarExpr],
638 rows: &[(Row, Diff)],
639 ) -> Result<Vec<(Row, Diff)>, EvalError> {
640 let mut new_rows = Vec::new();
641 let mut datum_vec = mz_repr::DatumVec::new();
642 'outer: for (row, diff) in rows {
643 let datums = datum_vec.borrow_with(row);
644 let temp_storage = RowArena::new();
645 for p in &*predicates {
646 if p.eval(&datums, &temp_storage)? != Datum::True {
647 continue 'outer;
648 }
649 }
650 new_rows.push((row.clone(), *diff))
651 }
652 Ok(new_rows)
653 }
654}