1use std::cmp::Ordering;
13use std::collections::{BTreeMap, BTreeSet};
14use std::convert::TryInto;
15use std::iter;
16
17use mz_expr::visit::Visit;
18use mz_expr::{
19 AggregateExpr, ColumnOrder, EvalError, MirRelationExpr, MirScalarExpr, TableFunc, UnaryFunc,
20};
21use mz_repr::{Datum, Diff, Row, RowArena, SqlRelationType};
22
23use crate::{TransformCtx, TransformError, any};
24
/// Transform that replaces operators over constant collections with the constant
/// collection they evaluate to.
#[derive(Debug)]
pub struct FoldConstants {
    /// Optional bound on the size of the constants this transform is willing to
    /// build. `None` means no bound is applied.
    pub limit: Option<usize>,
}
35
36impl crate::Transform for FoldConstants {
37 fn name(&self) -> &'static str {
38 "FoldConstants"
39 }
40
41 #[mz_ore::instrument(
42 target = "optimizer",
43 level = "debug",
44 fields(path.segment = "fold_constants")
45 )]
46 fn actually_perform_transform(
47 &self,
48 relation: &mut MirRelationExpr,
49 _: &mut TransformCtx,
50 ) -> Result<(), TransformError> {
51 let mut type_stack = Vec::new();
52 let result = relation.try_visit_mut_post(&mut |e| -> Result<(), TransformError> {
53 let num_inputs = e.num_inputs();
54 let input_types = &type_stack[type_stack.len() - num_inputs..];
55 let mut relation_type = e.typ_with_input_types(input_types);
56 self.action(e, &mut relation_type)?;
57 type_stack.truncate(type_stack.len() - num_inputs);
58 type_stack.push(relation_type);
59 Ok(())
60 });
61 mz_repr::explain::trace_plan(&*relation);
62 result
63 }
64}
65
66impl FoldConstants {
67 pub fn action(
73 &self,
74 relation: &mut MirRelationExpr,
75 relation_type: &mut SqlRelationType,
76 ) -> Result<(), TransformError> {
77 match relation {
78 MirRelationExpr::Constant { .. } => { }
79 MirRelationExpr::Get { .. } => {}
80 MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
81 }
84 MirRelationExpr::Reduce {
85 input,
86 group_key,
87 aggregates,
88 monotonic: _,
89 expected_group_size: _,
90 } => {
91 if group_key.iter().any(|e| e.contains_unmaterializable())
94 || aggregates
95 .iter()
96 .any(|a| a.expr.contains_unmaterializable())
97 {
98 return Ok(());
99 }
100
101 if let Some((rows, ..)) = (**input).as_const() {
102 let new_rows = match rows {
103 Ok(rows) => {
104 if let Some(rows) =
105 Self::fold_reduce_constant(group_key, aggregates, rows, self.limit)
106 {
107 rows
108 } else {
109 return Ok(());
110 }
111 }
112 Err(e) => Err(e.clone()),
113 };
114 *relation = MirRelationExpr::Constant {
115 rows: new_rows,
116 typ: relation_type.clone(),
117 };
118 }
119 }
120 MirRelationExpr::TopK {
121 input,
122 group_key,
123 order_key,
124 limit,
125 offset,
126 ..
127 } => {
128 if any![
136 limit.is_none(),
137 limit.as_ref().and_then(|l| l.as_literal_int64()) >= Some(0),
138 ] {
139 let limit = limit
140 .as_ref()
141 .and_then(|l| l.as_literal_int64().map(Into::into));
142 if let Some((rows, ..)) = (**input).as_const_mut() {
143 if let Ok(rows) = rows {
144 Self::fold_topk_constant(group_key, order_key, &limit, offset, rows);
145 }
146 *relation = input.take_dangerous();
147 }
148 }
149 }
150 MirRelationExpr::Negate { input } => {
151 if let Some((rows, ..)) = (**input).as_const_mut() {
152 if let Ok(rows) = rows {
153 for (_row, diff) in rows {
154 *diff = -*diff;
155 }
156 }
157 *relation = input.take_dangerous();
158 }
159 }
160 MirRelationExpr::Threshold { input } => {
161 if let Some((rows, ..)) = (**input).as_const_mut() {
162 if let Ok(rows) = rows {
163 rows.retain(|(_, diff)| diff.is_positive());
164 }
165 *relation = input.take_dangerous();
166 }
167 }
168 MirRelationExpr::Map { input, scalars } => {
169 if scalars.iter().any(|e| e.contains_unmaterializable()) {
172 return Ok(());
173 }
174
175 if let Some((rows, ..)) = (**input).as_const() {
176 if rows.as_ref().map_or(0, |r| r.len()) > 0 && scalars.len() == 1 {
181 if let MirScalarExpr::CallUnary {
182 func: UnaryFunc::Panic(_),
183 expr,
184 } = &scalars[0]
185 {
186 if let Some("forced panic") = expr.as_literal_str() {
187 let msg = "forced panic".to_string();
188 return Err(TransformError::CallerShouldPanic(msg));
189 }
190 }
191 }
192
193 let new_rows = match rows {
194 Ok(rows) => rows
195 .iter()
196 .map(|(input_row, diff)| {
197 let mut unpacked = input_row.unpack();
199 let temp_storage = RowArena::new();
200 for scalar in scalars.iter() {
201 unpacked.push(scalar.eval(&unpacked, &temp_storage)?)
202 }
203 Ok::<_, EvalError>((Row::pack_slice(&unpacked), *diff))
204 })
205 .collect::<Result<_, _>>(),
206 Err(e) => Err(e.clone()),
207 };
208 *relation = MirRelationExpr::Constant {
209 rows: new_rows,
210 typ: relation_type.clone(),
211 };
212 }
213 }
214 MirRelationExpr::FlatMap { input, func, exprs } => {
215 if exprs.iter().any(|e| e.contains_unmaterializable()) {
217 return Ok(());
218 }
219
220 if let Some((rows, ..)) = (**input).as_const() {
221 let new_rows = match rows {
222 Ok(rows) => Self::fold_flat_map_constant(func, exprs, rows, self.limit),
223 Err(e) => Err(e.clone()),
224 };
225 match new_rows {
226 Ok(None) => {}
227 Ok(Some(rows)) => {
228 *relation = MirRelationExpr::Constant {
229 rows: Ok(rows),
230 typ: relation_type.clone(),
231 };
232 }
233 Err(err) => {
234 *relation = MirRelationExpr::Constant {
235 rows: Err(err),
236 typ: relation_type.clone(),
237 };
238 }
239 };
240 }
241 }
242 MirRelationExpr::Filter { input, predicates } => {
243 if predicates.iter().any(|e| e.contains_unmaterializable()) {
246 return Ok(());
247 }
248
249 if predicates
251 .iter()
252 .any(|p| p.is_literal_false() || p.is_literal_null())
253 {
254 relation.take_safely(Some(relation_type.clone()));
255 } else if let Some((rows, ..)) = (**input).as_const() {
256 predicates.sort_by_key(|p| p.is_literal_err());
258 let new_rows = match rows {
259 Ok(rows) => Self::fold_filter_constant(predicates, rows),
260 Err(e) => Err(e.clone()),
261 };
262 *relation = MirRelationExpr::Constant {
263 rows: new_rows,
264 typ: relation_type.clone(),
265 };
266 }
267 }
268 MirRelationExpr::Project { input, outputs } => {
269 if let Some((rows, ..)) = (**input).as_const() {
270 let mut row_buf = Row::default();
271 let new_rows = match rows {
272 Ok(rows) => Ok(rows
273 .iter()
274 .map(|(input_row, diff)| {
275 let datums = input_row.unpack();
277 row_buf.packer().extend(outputs.iter().map(|i| &datums[*i]));
278 (row_buf.clone(), *diff)
279 })
280 .collect()),
281 Err(e) => Err(e.clone()),
282 };
283 *relation = MirRelationExpr::Constant {
284 rows: new_rows,
285 typ: relation_type.clone(),
286 };
287 }
288 }
289 MirRelationExpr::Join {
290 inputs,
291 equivalences,
292 ..
293 } => {
294 if inputs.iter().any(|e| e.is_empty()) {
295 relation.take_safely(Some(relation_type.clone()));
296 } else if let Some(e) = inputs.iter().find_map(|i| i.as_const_err()) {
297 *relation = MirRelationExpr::Constant {
298 rows: Err(e.clone()),
299 typ: relation_type.clone(),
300 };
301 } else if inputs
302 .iter()
303 .all(|i| matches!(i.as_const(), Some((Ok(_), ..))))
304 {
305 if equivalences
307 .iter()
308 .any(|equiv| equiv.iter().any(|e| e.contains_unmaterializable()))
309 {
310 return Ok(());
311 }
312
313 let mut old_rows = vec![(Row::pack::<_, Datum>(None), Diff::ONE)];
316 let mut row_buf = Row::default();
317 for input in inputs.iter() {
318 if let Some((Ok(rows), ..)) = input.as_const() {
319 if let Some(limit) = self.limit {
320 if old_rows.len() * rows.len() > limit {
321 return Ok(());
325 }
326 }
327 let mut next_rows = Vec::new();
328 for (old_row, old_count) in old_rows {
329 for (new_row, new_count) in rows.iter() {
330 let mut packer = row_buf.packer();
331 packer.extend_by_row(&old_row);
332 packer.extend_by_row(new_row);
333 next_rows.push((row_buf.clone(), old_count * *new_count));
334 }
335 }
336 old_rows = next_rows;
337 }
338 }
339
340 let mut datum_vec = mz_repr::DatumVec::new();
342 old_rows.retain(|(row, _count)| {
343 let datums = datum_vec.borrow_with(row);
344 let temp_storage = RowArena::new();
345 equivalences.iter().all(|equivalence| {
346 let mut values =
347 equivalence.iter().map(|e| e.eval(&datums, &temp_storage));
348 if let Some(value) = values.next() {
349 values.all(|v| v == value)
350 } else {
351 true
352 }
353 })
354 });
355
356 *relation = MirRelationExpr::Constant {
357 rows: Ok(old_rows),
358 typ: relation_type.clone(),
359 };
360 }
361 }
363 MirRelationExpr::Union { base, inputs } => {
364 if let Some(e) = iter::once(&mut **base)
365 .chain(&mut *inputs)
366 .find_map(|i| i.as_const_err())
367 {
368 *relation = MirRelationExpr::Constant {
369 rows: Err(e.clone()),
370 typ: relation_type.clone(),
371 };
372 } else {
373 let mut rows = vec![];
374 let mut new_inputs = vec![];
375
376 for input in iter::once(&mut **base).chain(&mut *inputs) {
377 if let Some((Ok(rs), ..)) = input.as_const() {
378 rows.extend(rs.clone());
379 } else {
380 new_inputs.push(input.clone())
381 }
382 }
383 if !rows.is_empty() {
384 new_inputs.push(MirRelationExpr::Constant {
385 rows: Ok(rows),
386 typ: relation_type.clone(),
387 });
388 }
389
390 *relation = MirRelationExpr::union_many(new_inputs, relation_type.clone());
391 }
392 }
393 MirRelationExpr::ArrangeBy { .. } => {
394 }
397 }
398
399 if let Some((Ok(rows), typ)) = relation.as_const_mut() {
404 differential_dataflow::consolidation::consolidate(rows);
406
407 for col_type in typ.column_types.iter_mut() {
409 col_type.nullable = false;
410 }
411 for (row, _) in rows.iter_mut() {
412 for (index, datum) in row.iter().enumerate() {
413 if datum.is_null() {
414 typ.column_types[index].nullable = true;
415 }
416 }
417 }
418 *relation_type = typ.clone();
419 }
420
421 Ok(())
422 }
423
424 #[allow(clippy::as_conversions)]
427 fn fold_reduce_constant(
428 group_key: &[MirScalarExpr],
429 aggregates: &[AggregateExpr],
430 rows: &[(Row, Diff)],
431 limit: Option<usize>,
432 ) -> Option<Result<Vec<(Row, Diff)>, EvalError>> {
433 let mut groups = BTreeMap::new();
437 let temp_storage2 = RowArena::new();
438 let mut row_buf = Row::default();
439 let mut limit_remaining =
440 limit.map_or(Diff::MAX, |limit| Diff::try_from(limit).expect("must fit"));
441 for (row, diff) in rows {
442 if *diff <= Diff::ZERO {
447 return Some(Err(EvalError::InvalidParameterValue(
448 "constant folding encountered reduce on collection with non-positive multiplicities".into()
449 )));
450 }
451
452 if limit_remaining < *diff {
453 return None;
454 }
455 limit_remaining -= diff;
456
457 let datums = row.unpack();
458 let temp_storage = RowArena::new();
459 let key = match group_key
460 .iter()
461 .map(|e| e.eval(&datums, &temp_storage2))
462 .collect::<Result<Vec<_>, _>>()
463 {
464 Ok(key) => key,
465 Err(e) => return Some(Err(e)),
466 };
467 let val = match aggregates
468 .iter()
469 .map(|agg| {
470 row_buf
471 .packer()
472 .extend([agg.expr.eval(&datums, &temp_storage)?]);
473 Ok::<_, EvalError>(row_buf.clone())
474 })
475 .collect::<Result<Vec<_>, _>>()
476 {
477 Ok(val) => val,
478 Err(e) => return Some(Err(e)),
479 };
480 let entry = groups.entry(key).or_insert_with(Vec::new);
481 for _ in 0..diff.into_inner() {
482 entry.push(val.clone());
483 }
484 }
485
486 let new_rows = groups
492 .into_iter()
493 .map({
494 let mut row_buf = Row::default();
495 move |(key, vals)| {
496 let temp_storage = RowArena::new();
497 row_buf.packer().extend(key.into_iter().chain(
498 aggregates.iter().enumerate().map(|(i, agg)| {
499 if agg.distinct {
500 agg.func.eval(
501 vals.iter()
502 .map(|val| val[i].unpack_first())
503 .collect::<BTreeSet<_>>(),
504 &temp_storage,
505 )
506 } else {
507 agg.func.eval(
508 vals.iter().map(|val| val[i].unpack_first()),
509 &temp_storage,
510 )
511 }
512 }),
513 ));
514 (row_buf.clone(), Diff::ONE)
515 }
516 })
517 .collect();
518 Some(Ok(new_rows))
519 }
520
521 fn fold_topk_constant<'a>(
522 group_key: &[usize],
523 order_key: &[ColumnOrder],
524 limit: &Option<Diff>,
525 offset: &usize,
526 rows: &'a mut [(Row, Diff)],
527 ) {
528 let mut lhs_datum_vec = mz_repr::DatumVec::new();
530 let mut rhs_datum_vec = mz_repr::DatumVec::new();
531 let mut cmp_order_key = |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
532 let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
533 let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
534 mz_expr::compare_columns(order_key, lhs_datums, rhs_datums, || lhs.cmp(rhs))
535 };
536 let mut cmp_group_key = {
537 let group_key = group_key
538 .iter()
539 .map(|column| ColumnOrder {
540 column: *column,
541 desc: false,
545 nulls_last: false,
546 })
547 .collect::<Vec<ColumnOrder>>();
548 let mut lhs_datum_vec = mz_repr::DatumVec::new();
549 let mut rhs_datum_vec = mz_repr::DatumVec::new();
550 move |lhs: &(Row, Diff), rhs: &(Row, Diff)| {
551 let lhs_datums = &lhs_datum_vec.borrow_with(&lhs.0);
552 let rhs_datums = &rhs_datum_vec.borrow_with(&rhs.0);
553 mz_expr::compare_columns(&group_key, lhs_datums, rhs_datums, || Ordering::Equal)
554 }
555 };
556
557 rows.sort_by(&mut cmp_order_key);
559
560 if !group_key.is_empty() {
562 rows.sort_by(&mut cmp_group_key);
563 };
564
565 let mut same_group_key =
566 |lhs: &(Row, Diff), rhs: &(Row, Diff)| cmp_group_key(lhs, rhs) == Ordering::Equal;
567
568 let mut cursor = 0;
569 while cursor < rows.len() {
570 let mut offset_rem: Diff = offset.clone().try_into().unwrap();
572 let mut limit_rem: Option<Diff> = limit.clone();
573
574 let mut finger = cursor;
575 while finger < rows.len() && same_group_key(&rows[cursor], &rows[finger]) {
576 if rows[finger].1.is_negative() {
577 rows[finger].1 = Diff::ZERO;
579 } else {
580 let rows_to_ignore = std::cmp::min(offset_rem, rows[finger].1);
583 rows[finger].1 -= rows_to_ignore;
584 offset_rem -= rows_to_ignore;
585 if let Some(limit_rem) = &mut limit_rem {
588 let rows_to_retain = std::cmp::min(*limit_rem, rows[finger].1);
589 rows[finger].1 = rows_to_retain;
590 *limit_rem -= rows_to_retain;
591 }
592 }
593 finger += 1;
594 }
595 cursor = finger;
596 }
597 }
598
599 fn fold_flat_map_constant(
600 func: &TableFunc,
601 exprs: &[MirScalarExpr],
602 rows: &[(Row, Diff)],
603 limit: Option<usize>,
604 ) -> Result<Option<Vec<(Row, Diff)>>, EvalError> {
605 let limit = limit.unwrap_or(usize::MAX);
607 let mut new_rows = Vec::new();
608 let mut row_buf = Row::default();
609 let mut datum_vec = mz_repr::DatumVec::new();
610 for (input_row, diff) in rows {
611 let datums = datum_vec.borrow_with(input_row);
612 let temp_storage = RowArena::new();
613 let datums = exprs
614 .iter()
615 .map(|expr| expr.eval(&datums, &temp_storage))
616 .collect::<Result<Vec<_>, _>>()?;
617 let mut output_rows = func.eval(&datums, &temp_storage)?.fuse();
618 for (output_row, diff2) in (&mut output_rows).take(limit - new_rows.len()) {
619 let mut packer = row_buf.packer();
620 packer.extend_by_row(input_row);
621 packer.extend_by_row(&output_row);
622 new_rows.push((row_buf.clone(), diff2 * *diff))
623 }
624 if output_rows.next() != None {
627 return Ok(None);
628 }
629 }
630 Ok(Some(new_rows))
631 }
632
633 fn fold_filter_constant(
634 predicates: &[MirScalarExpr],
635 rows: &[(Row, Diff)],
636 ) -> Result<Vec<(Row, Diff)>, EvalError> {
637 let mut new_rows = Vec::new();
638 let mut datum_vec = mz_repr::DatumVec::new();
639 'outer: for (row, diff) in rows {
640 let datums = datum_vec.borrow_with(row);
641 let temp_storage = RowArena::new();
642 for p in &*predicates {
643 if p.eval(&datums, &temp_storage)? != Datum::True {
644 continue 'outer;
645 }
646 }
647 new_rows.push((row.clone(), *diff))
648 }
649 Ok(new_rows)
650 }
651}