1use std::collections::{BTreeMap, BTreeSet};
20
21use itertools::Itertools;
22use mz_expr::visit::Visit;
23use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr, VariadicFunc};
24use mz_repr::{ReprColumnType, Row, SqlColumnType};
25
26use crate::analysis::{DerivedBuilder, ReprRelationType};
27use crate::{Transform, TransformCtx, TransformError};
28
/// Optimizer transform that rewrites chains of `If` expressions, each
/// comparing the same expression against a non-null literal, into a single
/// `CaseLiteral` variadic call.
#[derive(Debug)]
pub struct CaseLiteralTransform;
33
34impl Transform for CaseLiteralTransform {
35 fn name(&self) -> &'static str {
36 "CaseLiteralTransform"
37 }
38
39 #[mz_ore::instrument(
40 target = "optimizer",
41 level = "debug",
42 fields(path.segment = "case_literal")
43 )]
44 fn actually_perform_transform(
45 &self,
46 relation: &mut MirRelationExpr,
47 ctx: &mut TransformCtx,
48 ) -> Result<(), TransformError> {
49 let mut builder = DerivedBuilder::new(ctx.features);
51 builder.require(ReprRelationType);
52 let derived = builder.visit(&*relation);
53
54 let mut todo = vec![(&mut *relation, derived.as_view())];
55 while let Some((expr, view)) = todo.pop() {
56 match expr {
57 MirRelationExpr::Map { scalars, .. } => {
58 let output_type: &Vec<ReprColumnType> = view
60 .value::<ReprRelationType>()
61 .expect("ReprRelationType required")
62 .as_ref()
63 .unwrap();
64 let input_arity = output_type.len() - scalars.len();
65 for (index, scalar) in scalars.iter_mut().enumerate() {
66 Self::rewrite_scalar(scalar, &output_type[..input_arity + index])?;
67 }
68 }
69 MirRelationExpr::Filter { predicates, .. } => {
70 let input_type: &Vec<ReprColumnType> = view
71 .last_child()
72 .value::<ReprRelationType>()
73 .expect("ReprRelationType required")
74 .as_ref()
75 .unwrap();
76 for predicate in predicates.iter_mut() {
77 Self::rewrite_scalar(predicate, input_type)?;
78 }
79 }
80 MirRelationExpr::Reduce { aggregates, .. } => {
81 let input_type: &Vec<ReprColumnType> = view
82 .last_child()
83 .value::<ReprRelationType>()
84 .expect("ReprRelationType required")
85 .as_ref()
86 .unwrap();
87 for agg in aggregates.iter_mut() {
88 Self::rewrite_scalar(&mut agg.expr, input_type)?;
89 }
90 }
91 MirRelationExpr::FlatMap { exprs, .. } => {
92 let input_type: &Vec<ReprColumnType> = view
93 .last_child()
94 .value::<ReprRelationType>()
95 .expect("ReprRelationType required")
96 .as_ref()
97 .unwrap();
98 for e in exprs.iter_mut() {
99 Self::rewrite_scalar(e, input_type)?;
100 }
101 }
102 MirRelationExpr::Join { equivalences, .. } => {
103 let mut children: Vec<_> = view.children_rev().collect::<Vec<_>>();
104 children.reverse();
105 let input_types: Vec<ReprColumnType> = children
106 .iter()
107 .flat_map(|c| {
108 c.value::<ReprRelationType>()
109 .expect("ReprRelationType required")
110 .as_ref()
111 .unwrap()
112 .iter()
113 .cloned()
114 })
115 .collect();
116 for class in equivalences.iter_mut() {
117 for expr in class.iter_mut() {
118 Self::rewrite_scalar(expr, &input_types)?;
119 }
120 }
121 }
122 MirRelationExpr::TopK { limit, .. } => {
123 let input_type: &Vec<ReprColumnType> = view
124 .last_child()
125 .value::<ReprRelationType>()
126 .expect("ReprRelationType required")
127 .as_ref()
128 .unwrap();
129 if let Some(limit) = limit {
130 Self::rewrite_scalar(limit, input_type)?;
131 }
132 }
133 _ => {}
134 }
135 todo.extend(expr.children_mut().rev().zip_eq(view.children_rev()));
136 }
137
138 mz_repr::explain::trace_plan(&*relation);
139 Ok(())
140 }
141}
142
143impl CaseLiteralTransform {
144 fn rewrite_scalar(
147 expr: &mut MirScalarExpr,
148 column_types: &[ReprColumnType],
149 ) -> Result<(), TransformError> {
150 expr.try_visit_mut_post(&mut |node: &mut MirScalarExpr| {
151 try_fold_into_case_literal(node);
152 try_create_case_literal(node, column_types);
153 Ok(())
154 })
155 }
156}
157
158fn try_fold_into_case_literal(expr: &mut MirScalarExpr) {
166 let MirScalarExpr::If { cond, then, els } = expr else {
167 return;
168 };
169 let Some((common_candidate, literal_row)) = peek_eq_literal(cond) else {
170 return;
171 };
172 let MirScalarExpr::CallVariadic {
173 func: VariadicFunc::CaseLiteral(cl),
174 exprs,
175 } = els.as_mut()
176 else {
177 return;
178 };
179
180 if exprs[0] != *common_candidate {
182 return;
183 }
184
185 if let Some(&existing_idx) = cl.lookup.get(literal_row) {
186 exprs[existing_idx] = then.take();
188 } else {
189 let new_idx = exprs.len() - 1;
191 exprs.insert(new_idx, then.take());
192 cl.lookup.insert(literal_row.clone(), new_idx);
193 }
194
195 *expr = els.take();
197}
198
199fn try_create_case_literal(expr: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
202 if !has_at_least_two_arms(expr) {
203 return;
204 }
205
206 let chain = expr.take();
208 let (collected_cases, common, els) = collect_if_chain_arms(chain);
209
210 let common = common.expect("common expr must be set when arm_count >= 2");
211
212 let mut return_type: Option<ReprColumnType> = None;
214 for (_, result) in &collected_cases {
215 let t = result.typ(column_types);
216 return_type = Some(match return_type {
217 None => t,
218 Some(prev) => prev.union(&t).expect("incompatible branch types"),
219 });
220 }
221 let els_type = els.typ(column_types);
222 let return_type = match return_type {
223 Some(prev) => prev.union(&els_type).expect("incompatible else type"),
224 None => els_type,
225 };
226 let sql_return_type = SqlColumnType::from_repr(&return_type);
227
228 let mut exprs = Vec::with_capacity(collected_cases.len() + 2);
230 exprs.push(common);
231 let mut lookup = BTreeMap::new();
232 for (row, result_expr) in collected_cases {
233 let idx = exprs.len();
234 lookup.insert(row, idx);
235 exprs.push(result_expr);
236 }
237 exprs.push(els);
238
239 *expr = MirScalarExpr::CallVariadic {
240 func: VariadicFunc::CaseLiteral(mz_expr::func::CaseLiteral {
241 lookup,
242 return_type: sql_return_type,
243 }),
244 exprs,
245 };
246}
247
248fn has_at_least_two_arms(expr: &MirScalarExpr) -> bool {
251 let mut count = 0;
252 let mut common_candidate: Option<&MirScalarExpr> = None;
253 let mut current = expr;
254
255 loop {
256 match current {
257 MirScalarExpr::If { cond, then: _, els } => {
258 if let Some((expr_side, _literal_row)) = peek_eq_literal(cond) {
259 match common_candidate {
260 None => {
261 common_candidate = Some(expr_side);
262 }
263 Some(existing) => {
264 if existing != expr_side {
265 break;
266 }
267 }
268 }
269 count += 1;
270 if count >= 2 {
271 return true;
272 }
273 current = els;
274 } else {
275 break;
276 }
277 }
278 _ => break,
279 }
280 }
281
282 false
283}
284
285fn peek_eq_literal(cond: &MirScalarExpr) -> Option<(&MirScalarExpr, &Row)> {
289 let MirScalarExpr::CallBinary {
290 func: BinaryFunc::Eq(_),
291 expr1,
292 expr2,
293 } = cond
294 else {
295 return None;
296 };
297
298 if let Some(row) = expr1.as_literal_non_null_row() {
299 if !expr2.is_literal() {
300 return Some((expr2.as_ref(), row));
301 }
302 }
303 if let Some(row) = expr2.as_literal_non_null_row() {
304 if !expr1.is_literal() {
305 return Some((expr1.as_ref(), row));
306 }
307 }
308 None
309}
310
311fn collect_if_chain_arms(
316 chain: MirScalarExpr,
317) -> (
318 Vec<(Row, MirScalarExpr)>,
319 Option<MirScalarExpr>,
320 MirScalarExpr,
321) {
322 let mut cases = Vec::new();
323 let mut seen = BTreeSet::new();
324 let mut common_candidate: Option<MirScalarExpr> = None;
325 let mut remaining = chain;
326
327 loop {
328 match remaining {
329 MirScalarExpr::If { cond, then, els } => {
330 if let Some((expr_side, literal_row)) = peek_eq_literal(&cond) {
331 match &common_candidate {
332 None => {
333 common_candidate = Some(expr_side.clone());
334 }
335 Some(existing) => {
336 if existing != expr_side {
337 remaining = MirScalarExpr::If { cond, then, els };
338 break;
339 }
340 }
341 }
342
343 if seen.insert(literal_row.clone()) {
345 cases.push((literal_row.clone(), *then));
346 }
347
348 remaining = *els;
349 } else {
350 remaining = MirScalarExpr::If { cond, then, els };
351 break;
352 }
353 }
354 _ => break,
355 }
356 }
357
358 (cases, common_candidate, remaining)
359}
360
#[cfg(test)]
mod tests {
    use mz_expr::func::Eq;
    use mz_expr::{MirRelationExpr, MirScalarExpr, VariadicFunc};
    use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType};

    use super::*;

    /// Builds a non-null `Int64` literal.
    fn lit_i64(v: i64) -> MirScalarExpr {
        MirScalarExpr::literal_ok(Datum::Int64(v), ReprScalarType::Int64)
    }

    /// Wraps `scalar` in a `Map` over a one-row constant whose single
    /// non-nullable column has type `column_type` and holds `value`.
    fn map_over_single_column(
        value: Datum<'static>,
        column_type: ReprScalarType,
        scalar: MirScalarExpr,
    ) -> MirRelationExpr {
        MirRelationExpr::Map {
            input: Box::new(MirRelationExpr::constant(
                vec![vec![value]],
                ReprRelationType::new(vec![ReprColumnType {
                    scalar_type: column_type,
                    nullable: false,
                }]),
            )),
            scalars: vec![scalar],
        }
    }

    /// Runs `CaseLiteralTransform` (with its feature flag enabled) over a
    /// `Map` and returns the transformed first scalar.
    fn transform_first_scalar(mut relation: MirRelationExpr) -> MirScalarExpr {
        let mut features = mz_repr::optimize::OptimizerFeatures::default();
        features.enable_case_literal_transform = true;
        let typecheck_ctx = crate::typecheck::empty_typechecking_context();
        let mut df_meta = crate::dataflow::DataflowMetainfo::default();
        let mut transform_ctx =
            crate::TransformCtx::local(&features, &typecheck_ctx, &mut df_meta, None, None);
        crate::Transform::transform(&CaseLiteralTransform, &mut relation, &mut transform_ctx)
            .unwrap();
        match relation {
            MirRelationExpr::Map { scalars, .. } => scalars.into_iter().next().unwrap(),
            other => panic!("expected Map, got {other:?}"),
        }
    }

    /// Applies the transform to `scalar` evaluated over a single `Int64` column.
    fn apply_transform(scalar: MirScalarExpr) -> MirScalarExpr {
        transform_first_scalar(map_over_single_column(
            Datum::Int64(0),
            ReprScalarType::Int64,
            scalar,
        ))
    }

    /// Asserts that `expr` is a `CaseLiteral` with exactly `expected_cases`
    /// lookup entries.
    fn assert_case_literal(expr: &MirScalarExpr, expected_cases: usize) {
        let MirScalarExpr::CallVariadic {
            func: VariadicFunc::CaseLiteral(cl),
            ..
        } = expr
        else {
            panic!("expected CaseLiteral, got {expr:?}");
        };
        let actual = cl.lookup.len();
        assert_eq!(
            actual, expected_cases,
            "expected {expected_cases} cases, got {actual}"
        );
    }

    /// `if c0 = 1 then 10 else if c0 = 2 then 20 else 0`.
    fn build_2_arm_chain() -> MirScalarExpr {
        let second_arm = MirScalarExpr::column(0)
            .call_binary(lit_i64(2), Eq)
            .if_then_else(lit_i64(20), lit_i64(0));
        MirScalarExpr::column(0)
            .call_binary(lit_i64(1), Eq)
            .if_then_else(lit_i64(10), second_arm)
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_null_literal_skipped() {
        // `peek_eq_literal` rejects NULL literals, so the `= NULL` arm stays an
        // `If` while the two arms beneath it become a `CaseLiteral`.
        let third_arm = MirScalarExpr::column(0)
            .call_binary(lit_i64(3), Eq)
            .if_then_else(lit_i64(30), lit_i64(0));
        let second_arm = MirScalarExpr::column(0)
            .call_binary(lit_i64(2), Eq)
            .if_then_else(lit_i64(20), third_arm);
        let null_lit = MirScalarExpr::literal(Ok(Datum::Null), ReprScalarType::Int64);
        let expr = MirScalarExpr::column(0)
            .call_binary(null_lit, Eq)
            .if_then_else(lit_i64(10), second_arm);
        match apply_transform(expr) {
            MirScalarExpr::If { els, .. } => assert_case_literal(&els, 2),
            other => panic!("expected If with CaseLiteral in els, got {other:?}"),
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_64_arm_chain() {
        let n: usize = 64;
        // Build `c0 = i -> 100 * i` for i in 0..n, else -1, innermost arm first.
        let expr = (0..n).rev().fold(lit_i64(-1), |rest, i| {
            let i = i64::try_from(i).expect("arm index fits in i64");
            MirScalarExpr::column(0)
                .call_binary(lit_i64(i), Eq)
                .if_then_else(lit_i64(100 * i), rest)
        });
        let result = apply_transform(expr);
        assert_case_literal(&result, n);

        let arena = mz_repr::RowArena::new();
        for (input, expected) in [(0, 0), (32, 3200), (63, 6300), (999, -1)] {
            assert_eq!(
                result.eval(&[Datum::Int64(input)], &arena).unwrap(),
                Datum::Int64(expected)
            );
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_eval_basic() {
        let result = apply_transform(build_2_arm_chain());

        let arena = mz_repr::RowArena::new();
        for (input, expected) in [
            (Datum::Int64(1), Datum::Int64(10)),
            (Datum::Int64(2), Datum::Int64(20)),
            (Datum::Int64(99), Datum::Int64(0)),
            (Datum::Null, Datum::Int64(0)),
        ] {
            assert_eq!(result.eval(&[input], &arena).unwrap(), expected);
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_string_literals() {
        fn lit_str(s: &str) -> MirScalarExpr {
            MirScalarExpr::literal_ok(Datum::String(s), ReprScalarType::String)
        }
        let second_arm = MirScalarExpr::column(0)
            .call_binary(lit_str("b"), Eq)
            .if_then_else(lit_i64(2), lit_i64(0));
        let expr = MirScalarExpr::column(0)
            .call_binary(lit_str("a"), Eq)
            .if_then_else(lit_i64(1), second_arm);
        let result = transform_first_scalar(map_over_single_column(
            Datum::String("x"),
            ReprScalarType::String,
            expr,
        ));
        assert_case_literal(&result, 2);

        let arena = mz_repr::RowArena::new();
        for (input, expected) in [("a", 1), ("b", 2), ("z", 0)] {
            assert_eq!(
                result.eval(&[Datum::String(input)], &arena).unwrap(),
                Datum::Int64(expected)
            );
        }
    }
}