1use std::cmp::Ordering;
13use std::collections::{BTreeMap, BTreeSet};
14use std::convert::TryInto;
15use std::iter;
16
17use mz_expr::visit::Visit;
18use mz_expr::{
19 AggregateExpr, ColumnOrder, EvalError, MirRelationExpr, MirScalarExpr, TableFunc, UnaryFunc,
20};
21use mz_repr::{Datum, Diff, RelationType, Row, RowArena};
22
23use crate::{TransformCtx, TransformError, any};
24
/// A transform that replaces operators whose inputs are constant collections
/// with the constant collection they evaluate to.
#[derive(Debug)]
pub struct FoldConstants {
    /// An optional size bound applied while folding.
    ///
    /// It is consulted when folding `Reduce` (against the accumulated input
    /// multiplicities), `FlatMap` (against the number of produced rows), and
    /// `Join` (against the size of the cross product). When the bound would
    /// be exceeded the expression is left unfolded. `None` means no bound.
    pub limit: Option<usize>,
}
35
36impl crate::Transform for FoldConstants {
37 fn name(&self) -> &'static str {
38 "FoldConstants"
39 }
40
41 #[mz_ore::instrument(
42 target = "optimizer",
43 level = "debug",
44 fields(path.segment = "fold_constants")
45 )]
46 fn actually_perform_transform(
47 &self,
48 relation: &mut MirRelationExpr,
49 _: &mut TransformCtx,
50 ) -> Result<(), TransformError> {
51 let mut type_stack = Vec::new();
52 let result = relation.try_visit_mut_post(&mut |e| -> Result<(), TransformError> {
53 let num_inputs = e.num_inputs();
54 let input_types = &type_stack[type_stack.len() - num_inputs..];
55 let mut relation_type = e.typ_with_input_types(input_types);
56 self.action(e, &mut relation_type)?;
57 type_stack.truncate(type_stack.len() - num_inputs);
58 type_stack.push(relation_type);
59 Ok(())
60 });
61 mz_repr::explain::trace_plan(&*relation);
62 result
63 }
64}
65
66impl FoldConstants {
67 pub fn action(
73 &self,
74 relation: &mut MirRelationExpr,
75 relation_type: &mut RelationType,
76 ) -> Result<(), TransformError> {
77 match relation {
78 MirRelationExpr::Constant { .. } => { }
79 MirRelationExpr::Get { .. } => {}
80 MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
81 }
84 MirRelationExpr::Reduce {
85 input,
86 group_key,
87 aggregates,
88 monotonic: _,
89 expected_group_size: _,
90 } => {
91 if group_key.iter().any(|e| e.contains_unmaterializable())
94 || aggregates
95 .iter()
96 .any(|a| a.expr.contains_unmaterializable())
97 {
98 return Ok(());
99 }
100
101 if let Some((rows, ..)) = (**input).as_const() {
102 let new_rows = match rows {
103 Ok(rows) => {
104 if let Some(rows) =
105 Self::fold_reduce_constant(group_key, aggregates, rows, self.limit)
106 {
107 rows
108 } else {
109 return Ok(());
110 }
111 }
112 Err(e) => Err(e.clone()),
113 };
114 *relation = MirRelationExpr::Constant {
115 rows: new_rows,
116 typ: relation_type.clone(),
117 };
118 }
119 }
120 MirRelationExpr::TopK {
121 input,
122 group_key,
123 order_key,
124 limit,
125 offset,
126 ..
127 } => {
128 if any![
136 limit.is_none(),
137 limit.as_ref().and_then(|l| l.as_literal_int64()) >= Some(0),
138 ] {
139 let limit = limit
140 .as_ref()
141 .and_then(|l| l.as_literal_int64().map(Into::into));
142 if let Some((rows, ..)) = (**input).as_const_mut() {
143 if let Ok(rows) = rows {
144 Self::fold_topk_constant(group_key, order_key, &limit, offset, rows);
145 }
146 *relation = input.take_dangerous();
147 }
148 }
149 }
150 MirRelationExpr::Negate { input } => {
151 if let Some((rows, ..)) = (**input).as_const_mut() {
152 if let Ok(rows) = rows {
153 for (_row, diff) in rows {
154 *diff = -*diff;
155 }
156 }
157 *relation = input.take_dangerous();
158 }
159 }
160 MirRelationExpr::Threshold { input } => {
161 if let Some((rows, ..)) = (**input).as_const_mut() {
162 if let Ok(rows) = rows {
163 rows.retain(|(_, diff)| diff.is_positive());
164 }
165 *relation = input.take_dangerous();
166 }
167 }
168 MirRelationExpr::Map { input, scalars } => {
169 if scalars.iter().any(|e| e.contains_unmaterializable()) {
172 return Ok(());
173 }
174
175 if let Some((rows, ..)) = (**input).as_const() {
176 if rows.as_ref().map_or(0, |r| r.len()) > 0 && scalars.len() == 1 {
181 if let MirScalarExpr::CallUnary {
182 func: UnaryFunc::Panic(_),
183 expr,
184 } = &scalars[0]
185 {
186 if let Some("forced panic") = expr.as_literal_str() {
187 let msg = "forced panic".to_string();
188 return Err(TransformError::CallerShouldPanic(msg));
189 }
190 }
191 }
192
193 let new_rows = match rows {
194 Ok(rows) => rows
195 .iter()
196 .cloned()
197 .map(|(input_row, diff)| {
198 let mut unpacked = input_row.unpack();
200 let temp_storage = RowArena::new();
201 for scalar in scalars.iter() {
202 unpacked.push(scalar.eval(&unpacked, &temp_storage)?)
203 }
204 Ok::<_, EvalError>((Row::pack_slice(&unpacked), diff))
205 })
206 .collect::<Result<_, _>>(),
207 Err(e) => Err(e.clone()),
208 };
209 *relation = MirRelationExpr::Constant {
210 rows: new_rows,
211 typ: relation_type.clone(),
212 };
213 }
214 }
215 MirRelationExpr::FlatMap { input, func, exprs } => {
216 if exprs.iter().any(|e| e.contains_unmaterializable()) {
218 return Ok(());
219 }
220
221 if let Some((rows, ..)) = (**input).as_const() {
222 let new_rows = match rows {
223 Ok(rows) => Self::fold_flat_map_constant(func, exprs, rows, self.limit),
224 Err(e) => Err(e.clone()),
225 };
226 match new_rows {
227 Ok(None) => {}
228 Ok(Some(rows)) => {
229 *relation = MirRelationExpr::Constant {
230 rows: Ok(rows),
231 typ: relation_type.clone(),
232 };
233 }
234 Err(err) => {
235 *relation = MirRelationExpr::Constant {
236 rows: Err(err),
237 typ: relation_type.clone(),
238 };
239 }
240 };
241 }
242 }
243 MirRelationExpr::Filter { input, predicates } => {
244 if predicates.iter().any(|e| e.contains_unmaterializable()) {
247 return Ok(());
248 }
249
250 if predicates
252 .iter()
253 .any(|p| p.is_literal_false() || p.is_literal_null())
254 {
255 relation.take_safely(Some(relation_type.clone()));
256 } else if let Some((rows, ..)) = (**input).as_const() {
257 predicates.sort_by_key(|p| p.is_literal_err());
259 let new_rows = match rows {
260 Ok(rows) => Self::fold_filter_constant(predicates, rows),
261 Err(e) => Err(e.clone()),
262 };
263 *relation = MirRelationExpr::Constant {
264 rows: new_rows,
265 typ: relation_type.clone(),
266 };
267 }
268 }
269 MirRelationExpr::Project { input, outputs } => {
270 if let Some((rows, ..)) = (**input).as_const() {
271 let mut row_buf = Row::default();
272 let new_rows = match rows {
273 Ok(rows) => Ok(rows
274 .iter()
275 .map(|(input_row, diff)| {
276 let datums = input_row.unpack();
278 row_buf.packer().extend(outputs.iter().map(|i| &datums[*i]));
279 (row_buf.clone(), *diff)
280 })
281 .collect()),
282 Err(e) => Err(e.clone()),
283 };
284 *relation = MirRelationExpr::Constant {
285 rows: new_rows,
286 typ: relation_type.clone(),
287 };
288 }
289 }
290 MirRelationExpr::Join {
291 inputs,
292 equivalences,
293 ..
294 } => {
295 if inputs.iter().any(|e| e.is_empty()) {
296 relation.take_safely(Some(relation_type.clone()));
297 } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) {
298 *relation = MirRelationExpr::Constant {
299 rows: Err(e.clone()),
300 typ: relation_type.clone(),
301 };
302 } else if inputs
303 .iter()
304 .all(|i| matches!(i.as_const(), Some((Ok(_), ..))))
305 {
306 if equivalences
308 .iter()
309 .any(|equiv| equiv.iter().any(|e| e.contains_unmaterializable()))
310 {
311 return Ok(());
312 }
313
314 let mut old_rows = vec![(Row::pack::<_, Datum>(None), Diff::ONE)];
317 let mut row_buf = Row::default();
318 for input in inputs.iter() {
319 if let Some((Ok(rows), ..)) = input.as_const() {
320 if let Some(limit) = self.limit {
321 if old_rows.len() * rows.len() > limit {
322 return Ok(());
326 }
327 }
328 let mut next_rows = Vec::new();
329 for (old_row, old_count) in old_rows {
330 for (new_row, new_count) in rows.iter() {
331 let mut packer = row_buf.packer();
332 packer.extend_by_row(&old_row);
333 packer.extend_by_row(new_row);
334 next_rows.push((row_buf.clone(), old_count * *new_count));
335 }
336 }
337 old_rows = next_rows;
338 }
339 }
340
341 let mut datum_vec = mz_repr::DatumVec::new();
343 old_rows.retain(|(row, _count)| {
344 let datums = datum_vec.borrow_with(row);
345 let temp_storage = RowArena::new();
346 equivalences.iter().all(|equivalence| {
347 let mut values =
348 equivalence.iter().map(|e| e.eval(&datums, &temp_storage));
349 if let Some(value) = values.next() {
350 values.all(|v| v == value)
351 } else {
352 true
353 }
354 })
355 });
356
357 *relation = MirRelationExpr::Constant {
358 rows: Ok(old_rows),
359 typ: relation_type.clone(),
360 };
361 }
362 }
364 MirRelationExpr::Union { base, inputs } => {
365 if let Some(e) = iter::once(&mut **base)
366 .chain(&mut *inputs)
367 .find_map(|i| i.as_const_err())
368 {
369 *relation = MirRelationExpr::Constant {
370 rows: Err(e.clone()),
371 typ: relation_type.clone(),
372 };
373 } else {
374 let mut rows = vec![];
375 let mut new_inputs = vec![];
376
377 for input in iter::once(&mut **base).chain(&mut *inputs) {
378 if let Some((Ok(rs), ..)) = input.as_const() {
379 rows.extend(rs.clone());
380 } else {
381 new_inputs.push(input.clone())
382 }
383 }
384 if !rows.is_empty() {
385 new_inputs.push(MirRelationExpr::Constant {
386 rows: Ok(rows),
387 typ: relation_type.clone(),
388 });
389 }
390
391 *relation = MirRelationExpr::union_many(new_inputs, relation_type.clone());
392 }
393 }
394 MirRelationExpr::ArrangeBy { .. } => {
395 }
398 }
399
400 if let Some((Ok(rows), typ)) = relation.as_const_mut() {
405 differential_dataflow::consolidation::consolidate(rows);
407
408 for col_type in typ.column_types.iter_mut() {
410 col_type.nullable = false;
411 }
412 for (row, _) in rows.iter_mut() {
413 for (index, datum) in row.iter().enumerate() {
414 if datum.is_null() {
415 typ.column_types[index].nullable = true;
416 }
417 }
418 }
419 *relation_type = typ.clone();
420 }
421
422 Ok(())
423 }
424
425 #[allow(clippy::as_conversions)]
428 fn fold_reduce_constant(
429 group_key: &[MirScalarExpr],
430 aggregates: &[AggregateExpr],
431 rows: &[(Row, Diff)],
432 limit: Option<usize>,
433 ) -> Option<Result<Vec<(Row, Diff)>, EvalError>> {
434 let mut groups = BTreeMap::new();
438 let temp_storage2 = RowArena::new();
439 let mut row_buf = Row::default();
440 let mut limit_remaining =
441 limit.map_or(Diff::MAX, |limit| Diff::try_from(limit).expect("must fit"));
442 for (row, diff) in rows {
443 if *diff <= Diff::ZERO {
448 return Some(Err(EvalError::InvalidParameterValue(
449 "constant folding encountered reduce on collection with non-positive multiplicities".into()
450 )));
451 }
452
453 if limit_remaining < *diff {
454 return None;
455 }
456 limit_remaining -= diff;
457
458 let datums = row.unpack();
459 let temp_storage = RowArena::new();
460 let key = match group_key
461 .iter()
462 .map(|e| e.eval(&datums, &temp_storage2))
463 .collect::<Result<Vec<_>, _>>()
464 {
465 Ok(key) => key,
466 Err(e) => return Some(Err(e)),
467 };
468 let val = match aggregates
469 .iter()
470 .map(|agg| {
471 row_buf
472 .packer()
473 .extend([agg.expr.eval(&datums, &temp_storage)?]);
474 Ok::<_, EvalError>(row_buf.clone())
475 })
476 .collect::<Result<Vec<_>, _>>()
477 {
478 Ok(val) => val,
479 Err(e) => return Some(Err(e)),
480 };
481 let entry = groups.entry(key).or_insert_with(Vec::new);
482 for _ in 0..diff.into_inner() {
483 entry.push(val.clone());
484 }
485 }
486
487 let new_rows = groups
493 .into_iter()
494 .map({
495 let mut row_buf = Row::default();
496 move |(key, vals)| {
497 let temp_storage = RowArena::new();
498 row_buf.packer().extend(key.into_iter().chain(
499 aggregates.iter().enumerate().map(|(i, agg)| {
500 if agg.distinct {
501 agg.func.eval(
502 vals.iter()
503 .map(|val| val[i].unpack_first())
504 .collect::<BTreeSet<_>>(),
505 &temp_storage,
506 )
507 } else {
508 agg.func.eval(
509 vals.iter().map(|val| val[i].unpack_first()),
510 &temp_storage,
511 )
512 }
513 }),
514 ));
515 (row_buf.clone(), Diff::ONE)
516 }
517 })
518 .collect();
519 Some(Ok(new_rows))
520 }
521
522 fn fold_topk_constant<'a>(
523 group_key: &[usize],
524 order_key: &[ColumnOrder],
525 limit: &Option<Diff>,
526 offset: &usize,
527 rows: &'a mut [(Row, Diff)],
528 ) {
529 let mut lhs_datum_vec = mz_repr::DatumVec::new();
531 let mut rhs_datum_vec = mz_repr::DatumVec::new();
532 let mut cmp_order_key = |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
533 let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
534 let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
535 mz_expr::compare_columns(order_key, lhs_datums, rhs_datums, || lhs.cmp(rhs))
536 };
537 let mut cmp_group_key = {
538 let group_key = group_key
539 .iter()
540 .map(|column| ColumnOrder {
541 column: *column,
542 desc: false,
546 nulls_last: false,
547 })
548 .collect::<Vec<ColumnOrder>>();
549 let mut lhs_datum_vec = mz_repr::DatumVec::new();
550 let mut rhs_datum_vec = mz_repr::DatumVec::new();
551 move |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
552 let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
553 let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
554 mz_expr::compare_columns(&group_key, lhs_datums, rhs_datums, || Ordering::Equal)
555 }
556 };
557
558 rows.sort_by(&mut cmp_order_key);
560
561 if !group_key.is_empty() {
563 rows.sort_by(&mut cmp_group_key);
564 };
565
566 let mut same_group_key =
567 |lhs: &(Row, Diff), rhs: &(Row, Diff)| cmp_group_key(lhs, rhs) == Ordering::Equal;
568
569 let mut cursor = 0;
570 while cursor < rows.len() {
571 let mut offset_rem: Diff = offset.clone().try_into().unwrap();
573 let mut limit_rem: Option<Diff> = limit.clone();
574
575 let mut finger = cursor;
576 while finger < rows.len() && same_group_key(&rows[cursor], &rows[finger]) {
577 if rows[finger].1.is_negative() {
578 rows[finger].1 = Diff::ZERO;
580 } else {
581 let rows_to_ignore = std::cmp::min(offset_rem, rows[finger].1);
584 rows[finger].1 -= rows_to_ignore;
585 offset_rem -= rows_to_ignore;
586 if let Some(limit_rem) = &mut limit_rem {
589 let rows_to_retain = std::cmp::min(*limit_rem, rows[finger].1);
590 rows[finger].1 = rows_to_retain;
591 *limit_rem -= rows_to_retain;
592 }
593 }
594 finger += 1;
595 }
596 cursor = finger;
597 }
598 }
599
600 fn fold_flat_map_constant(
601 func: &TableFunc,
602 exprs: &[MirScalarExpr],
603 rows: &[(Row, Diff)],
604 limit: Option<usize>,
605 ) -> Result<Option<Vec<(Row, Diff)>>, EvalError> {
606 let limit = limit.unwrap_or(usize::MAX);
608 let mut new_rows = Vec::new();
609 let mut row_buf = Row::default();
610 let mut datum_vec = mz_repr::DatumVec::new();
611 for (input_row, diff) in rows {
612 let datums = datum_vec.borrow_with(input_row);
613 let temp_storage = RowArena::new();
614 let datums = exprs
615 .iter()
616 .map(|expr| expr.eval(&datums, &temp_storage))
617 .collect::<Result<Vec<_>, _>>()?;
618 let mut output_rows = func.eval(&datums, &temp_storage)?.fuse();
619 for (output_row, diff2) in (&mut output_rows).take(limit - new_rows.len()) {
620 let mut packer = row_buf.packer();
621 packer.extend_by_row(input_row);
622 packer.extend_by_row(&output_row);
623 new_rows.push((row_buf.clone(), diff2 * *diff))
624 }
625 if output_rows.next() != None {
628 return Ok(None);
629 }
630 }
631 Ok(Some(new_rows))
632 }
633
634 fn fold_filter_constant(
635 predicates: &[MirScalarExpr],
636 rows: &[(Row, Diff)],
637 ) -> Result<Vec<(Row, Diff)>, EvalError> {
638 let mut new_rows = Vec::new();
639 let mut datum_vec = mz_repr::DatumVec::new();
640 'outer: for (row, diff) in rows {
641 let datums = datum_vec.borrow_with(row);
642 let temp_storage = RowArena::new();
643 for p in &*predicates {
644 if p.eval(&datums, &temp_storage)? != Datum::True {
645 continue 'outer;
646 }
647 }
648 new_rows.push((row.clone(), *diff))
649 }
650 Ok(new_rows)
651 }
652}