// mz_expr/relation.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::explain::text::text_string_at;
34use mz_repr::explain::{
35 DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
36};
37use mz_repr::{
38 ColumnName, Datum, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
39 ReprScalarType, Row, RowIterator, SqlColumnType, SqlRelationType, SqlScalarType,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::Id::Local;
44use crate::explain::{HumanizedExpr, HumanizerMode};
45use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
46use crate::row::{RowCollection, RowCollectionIter};
47use crate::scalar::func::variadic::{
48 JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
49};
50use crate::visit::{Visit, VisitChildren};
51use crate::{
52 EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
53};
54
55pub mod canonicalize;
56pub mod func;
57pub mod join_input_mapper;
58
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
70
71/// A trait for types that describe how to build a collection.
72pub trait CollectionPlan {
73 /// Collects the set of global identifiers from dataflows referenced in Get.
74 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
75
76 /// Returns the set of global identifiers from dataflows referenced in Get.
77 ///
78 /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
79 fn depends_on(&self) -> BTreeSet<GlobalId> {
80 let mut out = BTreeSet::new();
81 self.depends_on_into(&mut out);
82 out
83 }
84}
85
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities;
        /// `Err` carries an evaluation error in place of rows.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: ReprRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: ReprRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        /// Parallel to `values` and `limits`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The argument to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection. (Kept separate from `inputs`, so a union always
        /// has at least one input.)
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
314
315impl PartialEq for MirRelationExpr {
316 fn eq(&self, other: &Self) -> bool {
317 // Capture the result and test it wrt `Ord` implementation in test environments.
318 let result = structured_diff::MreDiff::new(self, other).next().is_none();
319 mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
320 result
321 }
322}
323impl Eq for MirRelationExpr {}
324
325impl MirRelationExpr {
326 /// Reports the schema of the relation.
327 ///
328 /// This is the SQL-type parallel of [`Self::typ`]; it is merely
329 /// a wrapper around it, returning a [`SqlRelationType`] instead of
330 /// a [`ReprRelationType`].
331 pub fn sql_typ(&self) -> SqlRelationType {
332 let repr_typ = self.typ();
333 SqlRelationType::from_repr(&repr_typ)
334 }
335
    /// Reports the repr schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> ReprRelationType {
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: for `Let`/`LetRec`, descend only into `body`, since the
            // types of the bound values are not needed to type these operators.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
                    MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
                    _ => None,
                }
            },
            // Post-visit: pop the input types off `type_stack`, compute this
            // node's type, and push it back for the parent to consume.
            &mut |e: &MirRelationExpr| {
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(ReprRelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        type_stack.extend(
                            std::iter::repeat(ReprRelationType::empty()).take(values.len()),
                        );
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                // The top `num_inputs` entries of the stack are this node's input types.
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // After the traversal exactly the root's type remains.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
387
388 /// Reports the repr schema of the relation given the repr schema of the input relations.
389 pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
390 let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
391 let unique_keys = self.keys_with_input_keys(
392 input_types.iter().map(|i| i.arity()),
393 input_types.iter().map(|i| &i.keys),
394 );
395 ReprRelationType::new(column_types).with_keys(unique_keys)
396 }
397
398 /// Reports the column types of the relation given the column types of the
399 /// input relations.
400 ///
401 /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
402 /// variant is returned.
403 pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
404 where
405 I: Iterator<Item = &'a Vec<ReprColumnType>>,
406 {
407 match self.try_col_with_input_cols(input_types) {
408 Ok(col_types) => col_types,
409 Err(err) => panic!("{err}"),
410 }
411 }
412
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ReprColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ReprColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: a column stays nullable only if some row
                // actually contains a null in that column.
                let mut col_types = typ.column_types.clone();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null was observed, so the declared type must already be nullable.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                let mut result = input_types.next().unwrap().clone();
                // Each scalar may refer to columns appended by earlier scalars,
                // so the growing `result` is passed as its typing context.
                for scalar in scalars.iter() {
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                let mut result = input_types.next().unwrap().clone();
                result.extend(
                    func.output_sql_type()
                        .column_types
                        .iter()
                        .map(ReprColumnType::from),
                );
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output columns are the group key expressions followed by the aggregates.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Unify column types across all inputs; a type mismatch is reported
                // as an error annotated with the offending plan.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
545
546 /// Reports the unique keys of the relation given the arities and the unique
547 /// keys of the input relations.
548 ///
549 /// `input_arities` and `input_keys` are required to contain the
550 /// corresponding info for the input relations of
551 /// the current relation in the same order as they are visited by `try_visit_children`
552 /// method, even though not all may be used for computing the schema of the
553 /// current relation. For example, `Let` expects two input types, one for the
554 /// value relation and one for the body, in that order, but only the one for the
555 /// body is used to determine the type of the `Let` relation.
556 ///
557 /// It is meant to be used during post-order traversals to compute unique keys
558 /// incrementally.
559 pub fn keys_with_input_keys<'a, I, J>(
560 &self,
561 mut input_arities: I,
562 mut input_keys: J,
563 ) -> Vec<Vec<usize>>
564 where
565 I: Iterator<Item = usize>,
566 J: Iterator<Item = &'a Vec<Vec<usize>>>,
567 {
568 use MirRelationExpr::*;
569
570 let mut keys = match self {
571 Constant {
572 rows: Ok(rows),
573 typ,
574 } => {
575 let n_cols = typ.arity();
576 // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
577 let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
578 for (row, diff) in rows {
579 for (i, datum) in row.iter().enumerate() {
580 if datum != Datum::Dummy {
581 if let Some(unique_vals) = &mut unique_values_per_col[i] {
582 let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
583 if is_dupe {
584 unique_values_per_col[i] = None;
585 }
586 }
587 }
588 }
589 }
590 if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
591 vec![vec![]]
592 } else {
593 // XXX - Multi-column keys are not detected.
594 typ.keys
595 .iter()
596 .cloned()
597 .chain(
598 unique_values_per_col
599 .into_iter()
600 .enumerate()
601 .filter(|(_idx, unique_vals)| unique_vals.is_some())
602 .map(|(idx, _)| vec![idx]),
603 )
604 .collect()
605 }
606 }
607 Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
608 Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
609 Let { .. } => {
610 // skip over the unique keys for value
611 input_keys.nth(1).unwrap().clone()
612 }
613 LetRec { values, .. } => {
614 // skip over the unique keys for value
615 input_keys.nth(values.len()).unwrap().clone()
616 }
617 Project { outputs, .. } => {
618 let input = input_keys.next().unwrap();
619 input
620 .iter()
621 .filter_map(|key_set| {
622 if key_set.iter().all(|k| outputs.contains(k)) {
623 Some(
624 key_set
625 .iter()
626 .map(|c| outputs.iter().position(|o| o == c).unwrap())
627 .collect(),
628 )
629 } else {
630 None
631 }
632 })
633 .collect()
634 }
635 Map { scalars, .. } => {
636 let mut remappings = Vec::new();
637 let arity = input_arities.next().unwrap();
638 for (column, scalar) in scalars.iter().enumerate() {
639 // assess whether the scalar preserves uniqueness,
640 // and could participate in a key!
641
642 fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
643 match expr {
644 MirScalarExpr::CallUnary { func, expr } => {
645 if func.preserves_uniqueness() {
646 uniqueness(expr)
647 } else {
648 None
649 }
650 }
651 MirScalarExpr::Column(c, _name) => Some(*c),
652 _ => None,
653 }
654 }
655
656 if let Some(c) = uniqueness(scalar) {
657 remappings.push((c, column + arity));
658 }
659 }
660
661 let mut result = input_keys.next().unwrap().clone();
662 let mut new_keys = Vec::new();
663 // Any column in `remappings` could be replaced in a key
664 // by the corresponding c. This could lead to combinatorial
665 // explosion using our current representation, so we wont
666 // do that. Instead, we'll handle the case of one remapping.
667 if remappings.len() == 1 {
668 let (old, new) = remappings.pop().unwrap();
669 for key in &result {
670 if key.contains(&old) {
671 let mut new_key: Vec<usize> =
672 key.iter().cloned().filter(|k| k != &old).collect();
673 new_key.push(new);
674 new_key.sort_unstable();
675 new_keys.push(new_key);
676 }
677 }
678 result.append(&mut new_keys);
679 }
680 result
681 }
682 FlatMap { .. } => {
683 // FlatMap can add duplicate rows, so input keys are no longer
684 // valid
685 vec![]
686 }
687 Negate { .. } => {
688 // Although negate may have distinct records for each key,
689 // the multiplicity is -1 rather than 1. This breaks many
690 // of the optimization uses of "keys".
691 vec![]
692 }
693 Filter { predicates, .. } => {
694 // A filter inherits the keys of its input unless the filters
695 // have reduced the input to a single row, in which case the
696 // keys of the input are `()`.
697 let mut input = input_keys.next().unwrap().clone();
698
699 if !input.is_empty() {
700 // Track columns equated to literals, which we can prune.
701 let mut cols_equal_to_literal = BTreeSet::new();
702
703 // Perform union find on `col1 = col2` to establish
704 // connected components of equated columns. Absent any
705 // equalities, this will be `0 .. #c` (where #c is the
706 // greatest column referenced by a predicate), but each
707 // equality will orient the root of the greater to the root
708 // of the lesser.
709 let mut union_find = Vec::new();
710
711 for expr in predicates.iter() {
712 if let MirScalarExpr::CallBinary {
713 func: crate::BinaryFunc::Eq(_),
714 expr1,
715 expr2,
716 } = expr
717 {
718 if let MirScalarExpr::Column(c, _name) = &**expr1 {
719 if expr2.is_literal_ok() {
720 cols_equal_to_literal.insert(c);
721 }
722 }
723 if let MirScalarExpr::Column(c, _name) = &**expr2 {
724 if expr1.is_literal_ok() {
725 cols_equal_to_literal.insert(c);
726 }
727 }
728 // Perform union-find to equate columns.
729 if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
730 if c1 != c2 {
731 // Ensure union_find has entries up to
732 // max(c1, c2) by filling up missing
733 // positions with identity mappings.
734 while union_find.len() <= std::cmp::max(c1, c2) {
735 union_find.push(union_find.len());
736 }
737 let mut r1 = c1; // Find the representative column of [c1].
738 while r1 != union_find[r1] {
739 assert!(union_find[r1] < r1);
740 r1 = union_find[r1];
741 }
742 let mut r2 = c2; // Find the representative column of [c2].
743 while r2 != union_find[r2] {
744 assert!(union_find[r2] < r2);
745 r2 = union_find[r2];
746 }
747 // Union [c1] and [c2] by pointing the
748 // larger to the smaller representative (we
749 // update the remaining equivalence class
750 // members only once after this for-loop).
751 union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
752 }
753 }
754 }
755 }
756
757 // Complete union-find by pointing each element at its representative column.
758 for i in 0..union_find.len() {
759 // Iteration not required, as each prior already references the right column.
760 union_find[i] = union_find[union_find[i]];
761 }
762
763 // Remove columns bound to literals, and remap columns equated to earlier columns.
764 // We will re-expand remapped columns in a moment, but this avoids exponential work.
765 for key_set in &mut input {
766 key_set.retain(|k| !cols_equal_to_literal.contains(&k));
767 for col in key_set.iter_mut() {
768 if let Some(equiv) = union_find.get(*col) {
769 *col = *equiv;
770 }
771 }
772 key_set.sort();
773 key_set.dedup();
774 }
775 input.sort();
776 input.dedup();
777
778 // Expand out each key to each of its equivalent forms.
779 // Each instance of `col` can be replaced by any equivalent column.
780 // This has the potential to result in exponentially sized number of unique keys,
781 // and in the future we should probably maintain unique keys modulo equivalence.
782
783 // First, compute an inverse map from each representative
784 // column `sub` to all other equivalent columns `col`.
785 let mut subs = Vec::new();
786 for (col, sub) in union_find.iter().enumerate() {
787 if *sub != col {
788 assert!(*sub < col);
789 while subs.len() <= *sub {
790 subs.push(Vec::new());
791 }
792 subs[*sub].push(col);
793 }
794 }
795 // For each column, substitute for it in each occurrence.
796 let mut to_add = Vec::new();
797 for (col, subs) in subs.iter().enumerate() {
798 if !subs.is_empty() {
799 for key_set in input.iter() {
800 if key_set.contains(&col) {
801 let mut to_extend = key_set.clone();
802 to_extend.retain(|c| c != &col);
803 for sub in subs {
804 to_extend.push(*sub);
805 to_add.push(to_extend.clone());
806 to_extend.pop();
807 }
808 }
809 }
810 }
811 // No deduplication, as we cannot introduce duplicates.
812 input.append(&mut to_add);
813 }
814 for key_set in input.iter_mut() {
815 key_set.sort();
816 key_set.dedup();
817 }
818 }
819 input
820 }
821 Join { equivalences, .. } => {
822 // It is important the `new_from_input_arities` constructor is
823 // used. Otherwise, Materialize may potentially end up in an
824 // infinite loop.
825 let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
826
827 input_mapper.global_keys(input_keys, equivalences)
828 }
829 Reduce { group_key, .. } => {
830 // The group key should form a key, but we might already have
831 // keys that are subsets of the group key, and should retain
832 // those instead, if so.
833 let mut result = Vec::new();
834 for key_set in input_keys.next().unwrap() {
835 if key_set
836 .iter()
837 .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
838 {
839 result.push(
840 key_set
841 .iter()
842 .map(|i| {
843 group_key
844 .iter()
845 .position(|k| k == &MirScalarExpr::column(*i))
846 .unwrap()
847 })
848 .collect::<Vec<_>>(),
849 );
850 }
851 }
852 if result.is_empty() {
853 result.push((0..group_key.len()).collect());
854 }
855 result
856 }
857 TopK {
858 group_key, limit, ..
859 } => {
860 // If `limit` is `Some(1)` then the group key will become
861 // a unique key, as there will be only one record with that key.
862 let mut result = input_keys.next().unwrap().clone();
863 if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
864 result.push(group_key.clone())
865 }
866 result
867 }
868 Union { base, inputs } => {
869 // Generally, unions do not have any unique keys, because
870 // each input might duplicate some. However, there is at
871 // least one idiomatic structure that does preserve keys,
872 // which results from SQL aggregations that must populate
873 // absent records with default values. In that pattern,
874 // the union of one GET with its negation, which has first
875 // been subjected to a projection and map, we can remove
876 // their influence on the key structure.
877 //
878 // If there are A, B, each with a unique `key` such that
879 // we are looking at
880 //
881 // A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
882 //
883 // Then we can report `key` as a unique key.
884 //
885 // TODO: make unique key structure an optimization analysis
886 // rather than part of the type information.
887 // TODO: perhaps ensure that (above) A.proj(key) is a
888 // subset of B, as otherwise there are negative records
889 // and who knows what is true (not expected, but again
890 // who knows what the query plan might look like).
891
892 let arity = input_arities.next().unwrap();
893 let (base_projection, base_with_project_stripped) =
894 if let MirRelationExpr::Project { input, outputs } = &**base {
895 (outputs.clone(), &**input)
896 } else {
897 // A input without a project is equivalent to an input
898 // with the project being all columns in the input in order.
899 ((0..arity).collect::<Vec<_>>(), &**base)
900 };
901 let mut result = Vec::new();
902 if let MirRelationExpr::Get {
903 id: first_id,
904 typ: _,
905 ..
906 } = base_with_project_stripped
907 {
908 if inputs.len() == 1 {
909 if let MirRelationExpr::Map { input, .. } = &inputs[0] {
910 if let MirRelationExpr::Union { base, inputs } = &**input {
911 if inputs.len() == 1 {
912 if let Some((input, outputs)) = base.is_negated_project() {
913 if let MirRelationExpr::Get {
914 id: second_id,
915 typ: _,
916 ..
917 } = input
918 {
919 if first_id == second_id {
920 result.extend(
921 input_keys
922 .next()
923 .unwrap()
924 .into_iter()
925 .filter(|key| {
926 key.iter().all(|c| {
927 outputs.get(*c) == Some(c)
928 && base_projection.get(*c)
929 == Some(c)
930 })
931 })
932 .cloned(),
933 );
934 }
935 }
936 }
937 }
938 }
939 }
940 }
941 }
942 // Important: do not inherit keys of either input, as not unique.
943 result
944 }
945 };
946 keys.sort();
947 keys.dedup();
948 keys
949 }
950
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        // Arities of fully-visited subtrees, consumed by each parent when it
        // is visited post-order.
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: prune the traversal to only the children that can
            // influence the arity of the node.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post-visit: re-insert placeholder arities for the children that
            // the pre-visit skipped, then fold this node's own arity.
            &mut |e: &MirRelationExpr| {
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Only the body was visited; push a placeholder (0)
                        // for the skipped value so `arity_with_input_arities`
                        // sees the expected number of inputs.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // As above, but one placeholder per skipped value.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // Children were pruned entirely; push a placeholder
                        // for the single input of these operators.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Consume this node's input arities and push its own arity.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the traversal only the root's arity remains on the stack.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1008
1009 /// Reports the arity of the relation given the schema of the input relations.
1010 ///
1011 /// `input_arities` is required to contain the arities for the input relations of
1012 /// the current relation in the same order as they are visited by `try_visit_children`
1013 /// method, even though not all may be used for computing the schema of the
1014 /// current relation. For example, `Let` expects two input types, one for the
1015 /// value relation and one for the body, in that order, but only the one for the
1016 /// body is used to determine the type of the `Let` relation.
1017 ///
1018 /// It is meant to be used during post-order traversals to compute arities
1019 /// incrementally.
1020 pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1021 where
1022 I: Iterator<Item = usize>,
1023 {
1024 use MirRelationExpr::*;
1025
1026 match self {
1027 Constant { rows: _, typ } => typ.arity(),
1028 Get { typ, .. } => typ.arity(),
1029 Let { .. } => {
1030 input_arities.next();
1031 input_arities.next().unwrap()
1032 }
1033 LetRec { values, .. } => {
1034 for _ in 0..values.len() {
1035 input_arities.next();
1036 }
1037 input_arities.next().unwrap()
1038 }
1039 Project { outputs, .. } => outputs.len(),
1040 Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1041 FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1042 Join { .. } => input_arities.sum(),
1043 Reduce {
1044 input: _,
1045 group_key,
1046 aggregates,
1047 ..
1048 } => group_key.len() + aggregates.len(),
1049 Filter { .. }
1050 | TopK { .. }
1051 | Negate { .. }
1052 | Threshold { .. }
1053 | Union { .. }
1054 | ArrangeBy { .. } => input_arities.next().unwrap(),
1055 }
1056 }
1057
1058 /// The number of child relations this relation has.
1059 pub fn num_inputs(&self) -> usize {
1060 let mut count = 0;
1061
1062 self.visit_children(|_| count += 1);
1063
1064 count
1065 }
1066
1067 /// Constructs a constant collection from specific rows and schema, where
1068 /// each row will have a multiplicity of one.
1069 pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1070 let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1071 MirRelationExpr::constant_diff(rows, typ)
1072 }
1073
1074 /// Constructs a constant collection from specific rows and schema, where
1075 /// each row can have an arbitrary multiplicity.
1076 pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1077 for (row, _diff) in &rows {
1078 for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1079 assert!(
1080 datum.is_instance_of(column_typ),
1081 "Expected datum of type {:?}, got value {:?}",
1082 column_typ,
1083 datum
1084 );
1085 }
1086 }
1087 let rows = Ok(rows
1088 .into_iter()
1089 .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1090 .collect());
1091 MirRelationExpr::Constant { rows, typ }
1092 }
1093
1094 /// If self is a constant, return the value and the type, otherwise `None`.
1095 /// Looks behind `ArrangeBy`s.
1096 pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1097 match self {
1098 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1099 MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1100 _ => None,
1101 }
1102 }
1103
1104 /// If self is a constant, mutably return the value and the type, otherwise `None`.
1105 /// Looks behind `ArrangeBy`s.
1106 pub fn as_const_mut(
1107 &mut self,
1108 ) -> Option<(
1109 &mut Result<Vec<(Row, Diff)>, EvalError>,
1110 &mut ReprRelationType,
1111 )> {
1112 match self {
1113 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1114 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1115 _ => None,
1116 }
1117 }
1118
1119 /// If self is a constant error, return the error, otherwise `None`.
1120 /// Looks behind `ArrangeBy`s.
1121 pub fn as_const_err(&self) -> Option<&EvalError> {
1122 match self {
1123 MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1124 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1125 _ => None,
1126 }
1127 }
1128
1129 /// Checks if `self` is the single element collection with no columns.
1130 pub fn is_constant_singleton(&self) -> bool {
1131 if let Some((Ok(rows), typ)) = self.as_const() {
1132 rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1133 } else {
1134 false
1135 }
1136 }
1137
1138 /// Constructs the expression for getting a local collection.
1139 pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1140 MirRelationExpr::Get {
1141 id: Id::Local(id),
1142 typ,
1143 access_strategy: AccessStrategy::UnknownOrLocal,
1144 }
1145 }
1146
1147 /// Constructs the expression for getting a global collection
1148 pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1149 MirRelationExpr::Get {
1150 id: Id::Global(id),
1151 typ,
1152 access_strategy: AccessStrategy::UnknownOrLocal,
1153 }
1154 }
1155
1156 /// Retains only the columns specified by `output`.
1157 pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1158 if let MirRelationExpr::Project {
1159 outputs: columns, ..
1160 } = &mut self
1161 {
1162 // Update `outputs` to reference base columns of `input`.
1163 for column in outputs.iter_mut() {
1164 *column = columns[*column];
1165 }
1166 *columns = outputs;
1167 self
1168 } else {
1169 MirRelationExpr::Project {
1170 input: Box::new(self),
1171 outputs,
1172 }
1173 }
1174 }
1175
1176 /// Append to each row the results of applying elements of `scalar`.
1177 pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1178 if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1179 s.extend(scalars);
1180 self
1181 } else if !scalars.is_empty() {
1182 MirRelationExpr::Map {
1183 input: Box::new(self),
1184 scalars,
1185 }
1186 } else {
1187 self
1188 }
1189 }
1190
1191 /// Append to each row a single `scalar`.
1192 pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1193 self.map(vec![scalar])
1194 }
1195
1196 /// Like `map`, but yields zero-or-more output rows per input row
1197 pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1198 MirRelationExpr::FlatMap {
1199 input: Box::new(self),
1200 func,
1201 exprs,
1202 }
1203 }
1204
1205 /// Retain only the rows satisfying each of several predicates.
1206 pub fn filter<I>(mut self, predicates: I) -> Self
1207 where
1208 I: IntoIterator<Item = MirScalarExpr>,
1209 {
1210 // Extract existing predicates
1211 let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1212 self = *input;
1213 predicates
1214 } else {
1215 Vec::new()
1216 };
1217 // Normalize collection of predicates.
1218 new_predicates.extend(predicates);
1219 new_predicates.retain(|p| !p.is_literal_true());
1220 new_predicates.sort();
1221 new_predicates.dedup();
1222 // Introduce a `Filter` only if we have predicates.
1223 if !new_predicates.is_empty() {
1224 self = MirRelationExpr::Filter {
1225 input: Box::new(self),
1226 predicates: new_predicates,
1227 };
1228 }
1229
1230 self
1231 }
1232
1233 /// Form the Cartesian outer-product of rows in both inputs.
1234 pub fn product(mut self, right: Self) -> Self {
1235 if right.is_constant_singleton() {
1236 self
1237 } else if self.is_constant_singleton() {
1238 right
1239 } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1240 inputs.push(right);
1241 self
1242 } else {
1243 MirRelationExpr::join(vec![self, right], vec![])
1244 }
1245 }
1246
1247 /// Performs a relational equijoin among the input collections.
1248 ///
1249 /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1250 /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1251 /// of pairs `(input_index, column_index)` where every value described by that set must be equal.
1252 ///
1253 /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1254 /// input collection and for each row examining its `column`th column.
1255 ///
1256 /// # Example
1257 ///
1258 /// ```rust
1259 /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1260 /// use mz_expr::MirRelationExpr;
1261 ///
1262 /// // A common schema for each input.
1263 /// let schema = ReprRelationType::new(vec![
1264 /// ReprScalarType::Int32.nullable(false),
1265 /// ReprScalarType::Int32.nullable(false),
1266 /// ]);
1267 ///
1268 /// // the specific data are not important here.
1269 /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1270 ///
1271 /// // Three collections that could have been different.
1272 /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1273 /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1274 /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275 ///
1276 /// // Join the three relations looking for triangles, like so.
1277 /// //
1278 /// // Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1279 /// let joined = MirRelationExpr::join(
1280 /// vec![input0, input1, input2],
1281 /// vec![
1282 /// vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1283 /// vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1284 /// vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1285 /// ],
1286 /// );
1287 ///
1288 /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1289 /// // A projection resolves this and produces the correct output.
1290 /// let result = joined.project(vec![0, 1, 3]);
1291 /// ```
1292 pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1293 let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1294
1295 let equivalences = variables
1296 .into_iter()
1297 .map(|vs| {
1298 vs.into_iter()
1299 .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1300 .collect::<Vec<_>>()
1301 })
1302 .collect::<Vec<_>>();
1303
1304 Self::join_scalars(inputs, equivalences)
1305 }
1306
1307 /// Constructs a join operator from inputs and required-equal scalar expressions.
1308 pub fn join_scalars(
1309 mut inputs: Vec<MirRelationExpr>,
1310 equivalences: Vec<Vec<MirScalarExpr>>,
1311 ) -> Self {
1312 // Remove all constant inputs that are the identity for join.
1313 // They neither introduce nor modify any column references.
1314 inputs.retain(|i| !i.is_constant_singleton());
1315 MirRelationExpr::Join {
1316 inputs,
1317 equivalences,
1318 implementation: JoinImplementation::Unimplemented,
1319 }
1320 }
1321
1322 /// Perform a key-wise reduction / aggregation.
1323 ///
1324 /// The `group_key` argument indicates columns in the input collection that should
1325 /// be grouped, and `aggregates` lists aggregation functions each of which produces
1326 /// one output column in addition to the keys.
1327 pub fn reduce(
1328 self,
1329 group_key: Vec<usize>,
1330 aggregates: Vec<AggregateExpr>,
1331 expected_group_size: Option<u64>,
1332 ) -> Self {
1333 MirRelationExpr::Reduce {
1334 input: Box::new(self),
1335 group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1336 aggregates,
1337 monotonic: false,
1338 expected_group_size,
1339 }
1340 }
1341
1342 /// Perform a key-wise reduction order by and limit.
1343 ///
1344 /// The `group_key` argument indicates columns in the input collection that should
1345 /// be grouped, the `order_key` argument indicates columns that should be further
1346 /// used to order records within groups, and the `limit` argument constrains the
1347 /// total number of records that should be produced in each group.
1348 pub fn top_k(
1349 self,
1350 group_key: Vec<usize>,
1351 order_key: Vec<ColumnOrder>,
1352 limit: Option<MirScalarExpr>,
1353 offset: usize,
1354 expected_group_size: Option<u64>,
1355 ) -> Self {
1356 MirRelationExpr::TopK {
1357 input: Box::new(self),
1358 group_key,
1359 order_key,
1360 limit,
1361 offset,
1362 expected_group_size,
1363 monotonic: false,
1364 }
1365 }
1366
1367 /// Negates the occurrences of each row.
1368 pub fn negate(self) -> Self {
1369 if let MirRelationExpr::Negate { input } = self {
1370 *input
1371 } else {
1372 MirRelationExpr::Negate {
1373 input: Box::new(self),
1374 }
1375 }
1376 }
1377
1378 /// Removes all but the first occurrence of each row.
1379 pub fn distinct(self) -> Self {
1380 let arity = self.arity();
1381 self.distinct_by((0..arity).collect())
1382 }
1383
1384 /// Removes all but the first occurrence of each key. Columns not included
1385 /// in the `group_key` are discarded.
1386 pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1387 self.reduce(group_key, vec![], None)
1388 }
1389
1390 /// Discards rows with a negative frequency.
1391 pub fn threshold(self) -> Self {
1392 if let MirRelationExpr::Threshold { .. } = &self {
1393 self
1394 } else {
1395 MirRelationExpr::Threshold {
1396 input: Box::new(self),
1397 }
1398 }
1399 }
1400
1401 /// Unions together any number inputs.
1402 ///
1403 /// If `inputs` is empty, then an empty relation of type `typ` is
1404 /// constructed.
1405 pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1406 // Deconstruct `inputs` as `Union`s and reconstitute.
1407 let mut flat_inputs = Vec::with_capacity(inputs.len());
1408 for input in inputs {
1409 if let MirRelationExpr::Union { base, inputs } = input {
1410 flat_inputs.push(*base);
1411 flat_inputs.extend(inputs);
1412 } else {
1413 flat_inputs.push(input);
1414 }
1415 }
1416 inputs = flat_inputs;
1417 if inputs.len() == 0 {
1418 MirRelationExpr::Constant {
1419 rows: Ok(vec![]),
1420 typ,
1421 }
1422 } else if inputs.len() == 1 {
1423 inputs.into_element()
1424 } else {
1425 MirRelationExpr::Union {
1426 base: Box::new(inputs.remove(0)),
1427 inputs,
1428 }
1429 }
1430 }
1431
1432 /// Produces one collection where each row is present with the sum of its frequencies in each input.
1433 pub fn union(self, other: Self) -> Self {
1434 // Deconstruct `self` and `other` as `Union`s and reconstitute.
1435 let mut flat_inputs = Vec::with_capacity(2);
1436 if let MirRelationExpr::Union { base, inputs } = self {
1437 flat_inputs.push(*base);
1438 flat_inputs.extend(inputs);
1439 } else {
1440 flat_inputs.push(self);
1441 }
1442 if let MirRelationExpr::Union { base, inputs } = other {
1443 flat_inputs.push(*base);
1444 flat_inputs.extend(inputs);
1445 } else {
1446 flat_inputs.push(other);
1447 }
1448
1449 MirRelationExpr::Union {
1450 base: Box::new(flat_inputs.remove(0)),
1451 inputs: flat_inputs,
1452 }
1453 }
1454
1455 /// Arranges the collection by the specified columns
1456 pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1457 MirRelationExpr::ArrangeBy {
1458 input: Box::new(self),
1459 keys: keys.to_owned(),
1460 }
1461 }
1462
1463 /// Indicates if this is a constant empty collection.
1464 ///
1465 /// A false value does not mean the collection is known to be non-empty,
1466 /// only that we cannot currently determine that it is statically empty.
1467 pub fn is_empty(&self) -> bool {
1468 if let Some((Ok(rows), ..)) = self.as_const() {
1469 rows.is_empty()
1470 } else {
1471 false
1472 }
1473 }
1474
1475 /// If the expression is a negated project, return the input and the projection.
1476 pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1477 if let MirRelationExpr::Negate { input } = self {
1478 if let MirRelationExpr::Project { input, outputs } = &**input {
1479 return Some((&**input, outputs));
1480 }
1481 }
1482 if let MirRelationExpr::Project { input, outputs } = self {
1483 if let MirRelationExpr::Negate { input } = &**input {
1484 return Some((&**input, outputs));
1485 }
1486 }
1487 None
1488 }
1489
1490 /// Pretty-print this [MirRelationExpr] to a string.
1491 pub fn pretty(&self) -> String {
1492 let config = ExplainConfig::default();
1493 self.debug_explain(&config, None)
1494 }
1495
1496 /// Pretty-print this [MirRelationExpr] to a string using a custom
1497 /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1498 /// This is intended for debugging and tests, not users.
1499 pub fn debug_explain(
1500 &self,
1501 config: &ExplainConfig,
1502 humanizer: Option<&dyn ExprHumanizer>,
1503 ) -> String {
1504 text_string_at(self, || PlanRenderingContext {
1505 indent: Indent::default(),
1506 humanizer: humanizer.unwrap_or(&DummyHumanizer),
1507 annotations: BTreeMap::default(),
1508 config,
1509 ambiguous_ids: BTreeSet::default(),
1510 })
1511 }
1512
1513 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1514 /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1515 /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1516 /// the best possible key and nullability, since we are making an empty collection.
1517 ///
1518 /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1519 /// the correct type.
1520 pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1521 if let Some(typ) = &typ {
1522 let self_typ = self.typ();
1523 soft_assert_no_log!(
1524 self_typ
1525 .column_types
1526 .iter()
1527 .zip_eq(typ.column_types.iter())
1528 .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1529 );
1530 }
1531 let mut typ = typ.unwrap_or_else(|| self.typ());
1532 typ.keys = vec![vec![]];
1533 for ct in typ.column_types.iter_mut() {
1534 ct.nullable = false;
1535 }
1536 std::mem::replace(
1537 self,
1538 MirRelationExpr::Constant {
1539 rows: Ok(vec![]),
1540 typ,
1541 },
1542 )
1543 }
1544
1545 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1546 /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1547 /// possible nullability, since we are making an empty collection.
1548 pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1549 self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1550 }
1551
    /// Like [`Self::take_safely_with_sql_col_types`], but accepts `Vec<ReprColumnType>`.
    ///
    /// This is the preferred entry point for optimizer transforms, where repr
    /// types are the native currency. Wraps the columns in a
    /// [`ReprRelationType`] and delegates to [`Self::take_safely`].
    pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
        self.take_safely(Some(ReprRelationType::new(typ)))
    }
1560
1561 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1562 ///
1563 /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1564 pub fn take_dangerous(&mut self) -> MirRelationExpr {
1565 let empty = MirRelationExpr::Constant {
1566 rows: Ok(vec![]),
1567 typ: ReprRelationType::new(Vec::new()),
1568 };
1569 std::mem::replace(self, empty)
1570 }
1571
1572 /// Replaces `self` with some logic applied to `self`.
1573 pub fn replace_using<F>(&mut self, logic: F)
1574 where
1575 F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1576 {
1577 let empty = MirRelationExpr::Constant {
1578 rows: Ok(vec![]),
1579 typ: ReprRelationType::new(Vec::new()),
1580 };
1581 let expr = std::mem::replace(self, empty);
1582 *self = logic(expr);
1583 }
1584
1585 /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1586 pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1587 where
1588 Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1589 {
1590 if let MirRelationExpr::Get { .. } = self {
1591 // already done
1592 body(id_gen, self)
1593 } else {
1594 let id = LocalId::new(id_gen.allocate_id());
1595 let get = MirRelationExpr::Get {
1596 id: Id::Local(id),
1597 typ: self.typ(),
1598 access_strategy: AccessStrategy::UnknownOrLocal,
1599 };
1600 let body = (body)(id_gen, get)?;
1601 Ok(MirRelationExpr::Let {
1602 id,
1603 value: Box::new(self),
1604 body: Box::new(body),
1605 })
1606 }
1607 }
1608
    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
    pub fn anti_lookup<E>(
        self,
        id_gen: &mut IdGen,
        keys_and_values: MirRelationExpr,
        default: Vec<(Datum, ReprScalarType)>,
    ) -> Result<MirRelationExpr, E> {
        // Split `default` into the datums and their column types; a NULL
        // default makes the corresponding column nullable.
        let (data, column_types): (Vec<_>, Vec<_>) = default
            .into_iter()
            .map(|(datum, scalar_type)| {
                (
                    datum,
                    ReprColumnType {
                        scalar_type,
                        nullable: datum.is_null(),
                    },
                )
            })
            .unzip();
        // `default` must cover exactly the columns of `keys_and_values` that
        // extend beyond the columns of `self`.
        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
        // Bind `self` (the keys) once so it can be referenced twice below.
        self.let_in(id_gen, |_id_gen, get_keys| {
            let get_keys_arity = get_keys.arity();
            Ok(MirRelationExpr::join(
                vec![
                    // all the missing keys (with count 1)
                    keys_and_values
                        .distinct_by((0..get_keys_arity).collect())
                        .negate()
                        .union(get_keys.clone().distinct()),
                    // join with keys to get the correct counts
                    get_keys.clone(),
                ],
                // Equate every key column across the two join inputs.
                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // get rid of the extra copies of columns from keys
            .project((0..get_keys_arity).collect())
            // This join is logically equivalent to
            // `.map(<default_expr>)`, but using a join allows for
            // potential predicate pushdown and elision in the
            // optimizer.
            .product(MirRelationExpr::constant(
                vec![data],
                ReprRelationType::new(column_types),
            )))
        })
    }
1656
1657 /// Return:
1658 /// * every row in keys_and_values
1659 /// * every row in `self` that does not have a matching row in the first columns of
1660 /// `keys_and_values`, using `default` to fill in the remaining columns
1661 /// (This is LEFT OUTER JOIN if:
1662 /// 1) `default` is a row of null
1663 /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1664 pub fn lookup<E>(
1665 self,
1666 id_gen: &mut IdGen,
1667 keys_and_values: MirRelationExpr,
1668 default: Vec<(Datum<'static>, ReprScalarType)>,
1669 ) -> Result<MirRelationExpr, E> {
1670 keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1671 Ok(get_keys_and_values.clone().union(self.anti_lookup(
1672 id_gen,
1673 get_keys_and_values,
1674 default,
1675 )?))
1676 })
1677 }
1678
1679 /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1680 pub fn contains_temporal(&self) -> bool {
1681 let mut contains = false;
1682 self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1683 contains
1684 }
1685
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`, aborting the traversal.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            // Scalars appended to each input row.
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            // Predicates evaluated against each input row.
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            // Arguments to the table function.
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                // Scalars in the join's equivalence classes.
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // A planned join implementation carries key expressions too.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            // Key expressions of each requested arrangement.
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            // Grouping key expressions and aggregate arguments.
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            // The optional limit expression.
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1785
1786 /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1787 /// rooted at `self`.
1788 ///
1789 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1790 /// nodes.
1791 pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1792 where
1793 F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1794 E: From<RecursionLimitError>,
1795 {
1796 self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1797 }
1798
1799 /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1800 /// rooted at `self`.
1801 ///
1802 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1803 /// nodes.
1804 pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1805 where
1806 F: FnMut(&mut MirScalarExpr),
1807 {
1808 self.try_visit_scalars_mut(&mut |s| {
1809 f(s);
1810 Ok::<_, RecursionLimitError>(())
1811 })
1812 .expect("Unexpected error in `visit_scalars_mut` call");
1813 }
1814
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`, aborting the traversal.
    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            // Scalars appended to each input row.
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            // Predicates evaluated against each input row.
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            // Arguments to the table function.
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                // Scalars in the join's equivalence classes.
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // A planned join implementation carries key expressions too.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            // Key expressions of each requested arrangement.
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            // Grouping key expressions and aggregate arguments.
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&agg.expr)?;
                }
            }
            // The optional limit expression.
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1914
1915 /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1916 /// rooted at `self`.
1917 ///
1918 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1919 /// nodes.
1920 pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1921 where
1922 F: FnMut(&MirScalarExpr) -> Result<(), E>,
1923 E: From<RecursionLimitError>,
1924 {
1925 self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1926 }
1927
1928 /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1929 /// rooted at `self`.
1930 ///
1931 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1932 /// nodes.
1933 pub fn visit_scalars<F>(&self, f: &mut F)
1934 where
1935 F: FnMut(&MirScalarExpr),
1936 {
1937 self.try_visit_scalars(&mut |s| {
1938 f(s);
1939 Ok::<_, RecursionLimitError>(())
1940 })
1941 .expect("Unexpected error in `visit_scalars` call");
1942 }
1943
1944 /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1945 /// stack overflow in `drop_in_place`.
1946 ///
1947 /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1948 /// dropped or otherwise overwritten.
1949 pub fn destroy_carefully(&mut self) {
1950 let mut todo = vec![self.take_dangerous()];
1951 while let Some(mut expr) = todo.pop() {
1952 for child in expr.children_mut() {
1953 todo.push(child.take_dangerous());
1954 }
1955 }
1956 }
1957
1958 /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1959 /// debug printing purposes.
1960 pub fn debug_size_and_depth(&self) -> (usize, usize) {
1961 let mut size = 0;
1962 let mut max_depth = 0;
1963 let mut todo = vec![(self, 1)];
1964 while let Some((expr, depth)) = todo.pop() {
1965 size += 1;
1966 max_depth = max(max_depth, depth);
1967 todo.extend(expr.children().map(|c| (c, depth + 1)));
1968 }
1969 (size, max_depth)
1970 }
1971
1972 /// The MirRelationExpr is considered potentially expensive if and only if
1973 /// at least one of the following conditions is true:
1974 ///
1975 /// - It contains at least one FlatMap or a Reduce operator.
1976 /// - It contains at least one MirScalarExpr with a function call.
1977 ///
1978 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1979 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1980 pub fn could_run_expensive_function(&self) -> bool {
1981 let mut result = false;
1982 self.visit_pre(|e: &MirRelationExpr| {
1983 use MirRelationExpr::*;
1984 use MirScalarExpr::*;
1985 if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1986 result |= match scalar {
1987 Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1988 // Function calls are considered expensive
1989 CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1990 };
1991 Ok(())
1992 }) {
1993 // Conservatively set `true` if on RecursionLimitError.
1994 result = true;
1995 }
1996 // FlatMap has a table function; Reduce has an aggregate function.
1997 // Other constructs use MirScalarExpr to run a function
1998 result |= matches!(e, FlatMap { .. } | Reduce { .. });
1999 });
2000 result
2001 }
2002
2003 /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2004 /// than what `Hashable::hashed` would give us.)
2005 pub fn hash_to_u64(&self) -> u64 {
2006 let mut h = DefaultHasher::new();
2007 self.hash(&mut h);
2008 h.finish()
2009 }
2010}
2011
2012// `LetRec` helpers
2013impl MirRelationExpr {
2014 /// True when `expr` contains a `LetRec` AST node.
2015 pub fn is_recursive(self: &MirRelationExpr) -> bool {
2016 let mut worklist = vec![self];
2017 while let Some(expr) = worklist.pop() {
2018 if let MirRelationExpr::LetRec { .. } = expr {
2019 return true;
2020 }
2021 worklist.extend(expr.children());
2022 }
2023 false
2024 }
2025
2026 /// Return the number of sub-expressions in the tree (including self).
2027 pub fn size(&self) -> usize {
2028 let mut size = 0;
2029 self.visit_pre(|_| size += 1);
2030 size
2031 }
2032
2033 /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2034 /// iterations. These are those ids that have a reference before they are defined, when reading
2035 /// all the bindings in order.
2036 ///
2037 /// For example:
2038 /// ```SQL
2039 /// WITH MUTUALLY RECURSIVE
2040 /// x(...) AS f(z),
2041 /// y(...) AS g(x),
2042 /// z(...) AS h(y)
2043 /// ...;
2044 /// ```
2045 /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2046 /// iteration.
2047 ///
2048 /// Note that if a binding references itself, that is also returned.
2049 pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2050 let mut used_across_iterations = BTreeSet::new();
2051 let mut defined = BTreeSet::new();
2052 for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2053 value.visit_pre(|expr| {
2054 if let MirRelationExpr::Get {
2055 id: Local(get_id), ..
2056 } = expr
2057 {
2058 // If we haven't seen a definition for it yet, then this will refer
2059 // to the previous iteration.
2060 // The `ids.contains` part of the condition is needed to exclude
2061 // those ids that are not really in this LetRec, but either an inner
2062 // or outer one.
2063 if !defined.contains(get_id) && ids.contains(get_id) {
2064 used_across_iterations.insert(*get_id);
2065 }
2066 }
2067 });
2068 defined.insert(*binding_id);
2069 }
2070 used_across_iterations
2071 }
2072
2073 /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2074 ///
2075 /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2076 /// identifiers are rewritten to be the constant collection.
2077 /// This makes the computation perform exactly "one" iteration.
2078 ///
2079 /// This was used only temporarily while developing `LetRec`.
2080 pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2081 let mut deadlist = BTreeSet::new();
2082 let mut worklist = vec![self];
2083 while let Some(expr) = worklist.pop() {
2084 if let MirRelationExpr::LetRec {
2085 ids,
2086 values,
2087 limits: _,
2088 body,
2089 } = expr
2090 {
2091 let ids_values = values
2092 .drain(..)
2093 .zip_eq(ids)
2094 .map(|(value, id)| (*id, value))
2095 .collect::<Vec<_>>();
2096 *expr = body.take_dangerous();
2097 for (id, mut value) in ids_values.into_iter().rev() {
2098 // Remove references to potentially recursive identifiers.
2099 deadlist.insert(id);
2100 value.visit_pre_mut(|e| {
2101 if let MirRelationExpr::Get {
2102 id: crate::Id::Local(id),
2103 typ,
2104 ..
2105 } = e
2106 {
2107 let typ = typ.clone();
2108 if deadlist.contains(id) {
2109 e.take_safely(Some(typ));
2110 }
2111 }
2112 });
2113 *expr = MirRelationExpr::Let {
2114 id,
2115 value: Box::new(value),
2116 body: Box::new(expr.take_dangerous()),
2117 };
2118 }
2119 worklist.push(expr);
2120 } else {
2121 worklist.extend(expr.children_mut().rev());
2122 }
2123 }
2124 }
2125
2126 /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2127 /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2128 /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2129 /// redefinition.
2130 ///
2131 /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2132 pub fn collect_expirations(
2133 id: LocalId,
2134 expr: &MirRelationExpr,
2135 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2136 ) {
2137 expr.visit_pre(|e| {
2138 if let MirRelationExpr::Get {
2139 id: Id::Local(referenced_id),
2140 ..
2141 } = e
2142 {
2143 // The following check needs `renumber_bindings` to have run recently
2144 if referenced_id >= &id {
2145 expire_whens
2146 .entry(*referenced_id)
2147 .or_insert_with(Vec::new)
2148 .push(id);
2149 }
2150 }
2151 });
2152 }
2153
2154 /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2155 /// about such Ids whose information depended on the earlier definition of `id`, according to
2156 /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2157 pub fn do_expirations<I>(
2158 redefined_id: LocalId,
2159 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2160 id_infos: &mut BTreeMap<LocalId, I>,
2161 ) -> Vec<(LocalId, I)> {
2162 let mut expired_infos = Vec::new();
2163 if let Some(expirations) = expire_whens.remove(&redefined_id) {
2164 for expired_id in expirations.into_iter() {
2165 if let Some(offer) = id_infos.remove(&expired_id) {
2166 expired_infos.push((expired_id, offer));
2167 }
2168 }
2169 }
2170 expired_infos
2171 }
2172}
2173/// Augment non-nullability of columns, by observing either
2174/// 1. Predicates that explicitly test for null values, and
2175/// 2. Columns that if null would make a predicate be null.
2176pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2177 let mut nonnull_required_columns = BTreeSet::new();
2178 for predicate in predicates {
2179 // Add any columns that being null would force the predicate to be null.
2180 // Should that happen, the row would be discarded.
2181 predicate.non_null_requirements(&mut nonnull_required_columns);
2182
2183 /*
2184 Test for explicit checks that a column is non-null.
2185
2186 This analysis is ad hoc, and will miss things:
2187
2188 materialize=> create table a(x int, y int);
2189 CREATE TABLE
2190 materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2191 Optimized Plan
2192 --------------------------------------------------------------------------------------------------------
2193 Explained Query: +
2194 Project (#0) // { types: "(integer?)" } +
2195 Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2196 Get materialize.public.a // { types: "(integer?, integer?)" } +
2197 +
2198 Source materialize.public.a +
2199 filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1)))) +
2200
2201 (1 row)
2202 */
2203
2204 if let MirScalarExpr::CallUnary {
2205 func: UnaryFunc::Not(scalar_func::Not),
2206 expr,
2207 } = predicate
2208 {
2209 if let MirScalarExpr::CallUnary {
2210 func: UnaryFunc::IsNull(scalar_func::IsNull),
2211 expr,
2212 } = &**expr
2213 {
2214 if let MirScalarExpr::Column(c, _name) = &**expr {
2215 nonnull_required_columns.insert(*c);
2216 }
2217 }
2218 }
2219 }
2220
2221 nonnull_required_columns
2222}
2223
2224impl CollectionPlan for MirRelationExpr {
2225 /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2226 /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2227 ///
2228 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2229 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2230 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2231 if let MirRelationExpr::Get {
2232 id: Id::Global(id), ..
2233 } = self
2234 {
2235 out.insert(*id);
2236 }
2237 self.visit_children(|expr| expr.depends_on_into(out))
2238 }
2239}
2240
2241impl MirRelationExpr {
2242 /// Iterates through references to child expressions.
2243 pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2244 let mut first = None;
2245 let mut second = None;
2246 let mut rest = None;
2247 let mut last = None;
2248
2249 use MirRelationExpr::*;
2250 match self {
2251 Constant { .. } | Get { .. } => (),
2252 Let { value, body, .. } => {
2253 first = Some(&**value);
2254 second = Some(&**body);
2255 }
2256 LetRec { values, body, .. } => {
2257 rest = Some(values);
2258 last = Some(&**body);
2259 }
2260 Project { input, .. }
2261 | Map { input, .. }
2262 | FlatMap { input, .. }
2263 | Filter { input, .. }
2264 | Reduce { input, .. }
2265 | TopK { input, .. }
2266 | Negate { input }
2267 | Threshold { input }
2268 | ArrangeBy { input, .. } => {
2269 first = Some(&**input);
2270 }
2271 Join { inputs, .. } => {
2272 rest = Some(inputs);
2273 }
2274 Union { base, inputs } => {
2275 first = Some(&**base);
2276 rest = Some(inputs);
2277 }
2278 }
2279
2280 first
2281 .into_iter()
2282 .chain(second)
2283 .chain(rest.into_iter().flatten())
2284 .chain(last)
2285 }
2286
2287 /// Iterates through mutable references to child expressions.
2288 pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2289 let mut first = None;
2290 let mut second = None;
2291 let mut rest = None;
2292 let mut last = None;
2293
2294 use MirRelationExpr::*;
2295 match self {
2296 Constant { .. } | Get { .. } => (),
2297 Let { value, body, .. } => {
2298 first = Some(&mut **value);
2299 second = Some(&mut **body);
2300 }
2301 LetRec { values, body, .. } => {
2302 rest = Some(values);
2303 last = Some(&mut **body);
2304 }
2305 Project { input, .. }
2306 | Map { input, .. }
2307 | FlatMap { input, .. }
2308 | Filter { input, .. }
2309 | Reduce { input, .. }
2310 | TopK { input, .. }
2311 | Negate { input }
2312 | Threshold { input }
2313 | ArrangeBy { input, .. } => {
2314 first = Some(&mut **input);
2315 }
2316 Join { inputs, .. } => {
2317 rest = Some(inputs);
2318 }
2319 Union { base, inputs } => {
2320 first = Some(&mut **base);
2321 rest = Some(inputs);
2322 }
2323 }
2324
2325 first
2326 .into_iter()
2327 .chain(second)
2328 .chain(rest.into_iter().flatten())
2329 .chain(last)
2330 }
2331
2332 /// Iterative pre-order visitor.
2333 pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2334 let mut worklist = vec![self];
2335 while let Some(expr) = worklist.pop() {
2336 f(expr);
2337 worklist.extend(expr.children().rev());
2338 }
2339 }
2340
2341 /// Iterative pre-order visitor.
2342 pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2343 let mut worklist = vec![self];
2344 while let Some(expr) = worklist.pop() {
2345 f(expr);
2346 worklist.extend(expr.children_mut().rev());
2347 }
2348 }
2349
2350 /// Return a vector of references to the subtrees of this expression
2351 /// in post-visit order (the last element is `&self`).
2352 pub fn post_order_vec(&self) -> Vec<&Self> {
2353 let mut stack = vec![self];
2354 let mut result = vec![];
2355 while let Some(expr) = stack.pop() {
2356 result.push(expr);
2357 stack.extend(expr.children());
2358 }
2359 result.reverse();
2360 result
2361 }
2362}
2363
2364impl VisitChildren<Self> for MirRelationExpr {
2365 fn visit_children<F>(&self, mut f: F)
2366 where
2367 F: FnMut(&Self),
2368 {
2369 for child in self.children() {
2370 f(child)
2371 }
2372 }
2373
2374 fn visit_mut_children<F>(&mut self, mut f: F)
2375 where
2376 F: FnMut(&mut Self),
2377 {
2378 for child in self.children_mut() {
2379 f(child)
2380 }
2381 }
2382
2383 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2384 where
2385 F: FnMut(&Self) -> Result<(), E>,
2386 E: From<RecursionLimitError>,
2387 {
2388 for child in self.children() {
2389 f(child)?
2390 }
2391 Ok(())
2392 }
2393
2394 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2395 where
2396 F: FnMut(&mut Self) -> Result<(), E>,
2397 E: From<RecursionLimitError>,
2398 {
2399 for child in self.children_mut() {
2400 f(child)?
2401 }
2402 Ok(())
2403 }
2404}
2405
/// Specification for an ordering by a column.
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    // `#[serde(default)]`: a missing field deserializes as `false`, i.e., ascending.
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    // `#[serde(default)]`: a missing field deserializes as `false`, i.e., nulls first.
    #[serde(default)]
    pub nulls_last: bool,
}
2430
impl Columnation for ColumnOrder {
    // `ColumnOrder` derives `Copy`, so the simple memcpy-style region suffices.
    type InnerRegion = CopyRegion<Self>;
}
2434
2435impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2436where
2437 M: HumanizerMode,
2438{
2439 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2440 // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2441 write!(
2442 f,
2443 "{} {} {}",
2444 self.child(&self.expr.column),
2445 if self.expr.desc { "desc" } else { "asc" },
2446 if self.expr.nulls_last {
2447 "nulls_last"
2448 } else {
2449 "nulls_first"
2450 },
2451 )
2452 }
2453}
2454
/// Describes an aggregation expression.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    // `#[serde(default)]`: a missing field deserializes as `false` (no DISTINCT).
    #[serde(default)]
    pub distinct: bool,
}
2477
2478impl AggregateExpr {
2479 /// Computes the type of this `AggregateExpr`.
2480 pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2481 self.func.output_sql_type(self.expr.sql_typ(column_types))
2482 }
2483
2484 /// Computes the type of this `AggregateExpr`.
2485 pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2486 self.func.output_type(self.expr.typ(column_types))
2487 }
2488
    /// Returns whether the expression has a constant result.
    pub fn is_constant(&self) -> bool {
        match self.func {
            // Selection-style aggregates: over a single (literal) value the output is that
            // value, so the aggregate is constant exactly when the input expression is a
            // literal.
            AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::Any
            | AggregateFunc::All
            | AggregateFunc::Dummy => self.expr.is_literal(),
            // COUNT of a null literal always yields 0.
            AggregateFunc::Count => self.expr.is_literal_null(),
            // NOTE(review): the remaining variants (including MaxNumeric/MinNumeric,
            // MaxInterval/MinInterval, MaxTime/MinTime, the sums, and window functions) count
            // as constant only when the input is a literal error — confirm the asymmetry with
            // the Min/Max list above is intentional.
            _ => self.expr.is_literal_err(),
        }
    }
2527
    /// Returns an expression that computes `self` on a group that has exactly one row.
    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
    ///
    /// The window-function arms below unpack `self.expr`, whose value is encoded as nested
    /// records; the exact shape is stated in the comment on each arm.
    pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
        match &self.func {
            // Count is one if non-null, and zero if null.
            AggregateFunc::Count => self
                .expr
                .clone()
                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
                .if_then_else(
                    MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
                ),

            // SumInt16 takes Int16s as input, but outputs Int64s.
            AggregateFunc::SumInt16 => self
                .expr
                .clone()
                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),

            // SumInt32 takes Int32s as input, but outputs Int64s.
            AggregateFunc::SumInt32 => self
                .expr
                .clone()
                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),

            // SumInt64 takes Int64s as input, but outputs numerics.
            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
            )),

            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
            ),

            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
            ),

            // SumUInt64 takes UInt64s as input, but outputs numerics.
            AggregateFunc::SumUInt64 => {
                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
                ))
            }

            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
                JsonbBuildArray,
                vec![
                    self.expr
                        .clone()
                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
                ],
            ),

            // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
            AggregateFunc::JsonbObjectAgg { .. } => {
                let record = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
                // Fields 0 and 1 of the record are the key and value, respectively.
                MirScalarExpr::call_variadic(
                    JsonbBuildObject,
                    (0..2)
                        .map(|i| {
                            record
                                .clone()
                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
                        })
                        .collect(),
                )
            }

            // MapAgg takes (key, value) records as input and outputs a map.
            AggregateFunc::MapAgg { value_type, .. } => {
                let record = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
                // Fields 0 and 1 of the record are the key and value, respectively.
                MirScalarExpr::call_variadic(
                    MapBuild {
                        value_type: value_type.clone(),
                    },
                    (0..2)
                        .map(|i| {
                            record
                                .clone()
                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
                        })
                        .collect(),
                )
            }

            // StringAgg takes nested records of strings and outputs a string
            AggregateFunc::StringAgg { .. } => self
                .expr
                .clone()
                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),

            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
                .expr
                .clone()
                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),

            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
            AggregateFunc::RowNumber { .. } => {
                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
            }
            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
            AggregateFunc::DenseRank { .. } => {
                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
            }

            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
            AggregateFunc::LagLead { lag_lead, .. } => {
                // Strip the OrderByExprs, keeping (OriginalRow, (InputValue, Offset, Default)).
                let tuple = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Get the overall return type
                let return_type_with_orig_row = self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone();
                let lag_lead_return_type =
                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();

                // Extract the original row
                let original_row = tuple
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the encoded args
                let encoded_args =
                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

                let (result_expr, column_name) =
                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());

                // Wrap the result back up as a 1-element list of (result, original row).
                MirScalarExpr::call_variadic(
                    ListCreate {
                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
                    },
                    vec![MirScalarExpr::call_variadic(
                        RecordCreate {
                            field_names: vec![column_name, ColumnName::from("?record?")],
                        },
                        vec![result_expr, original_row],
                    )],
                )
            }

            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
            AggregateFunc::FirstValue { window_frame, .. } => {
                let tuple = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Get the overall return type
                let return_type_with_orig_row = self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone();
                let first_value_return_type =
                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();

                // Extract the original row
                let original_row = tuple
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the input value
                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
                    window_frame,
                    arg,
                    first_value_return_type,
                );

                // Wrap the result back up as a 1-element list of (result, original row).
                MirScalarExpr::call_variadic(
                    ListCreate {
                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
                    },
                    vec![MirScalarExpr::call_variadic(
                        RecordCreate {
                            field_names: vec![column_name, ColumnName::from("?record?")],
                        },
                        vec![result_expr, original_row],
                    )],
                )
            }

            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
            AggregateFunc::LastValue { window_frame, .. } => {
                let tuple = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Get the overall return type
                let return_type_with_orig_row = self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone();
                let last_value_return_type =
                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();

                // Extract the original row
                let original_row = tuple
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the input value
                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
                    window_frame,
                    arg,
                    last_value_return_type,
                );

                // Wrap the result back up as a 1-element list of (result, original row).
                MirScalarExpr::call_variadic(
                    ListCreate {
                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
                    },
                    vec![MirScalarExpr::call_variadic(
                        RecordCreate {
                            field_names: vec![column_name, ColumnName::from("?record?")],
                        },
                        vec![result_expr, original_row],
                    )],
                )
            }

            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
            // See an example MIR in `window_func_applied_to`.
            AggregateFunc::WindowAggregate {
                wrapped_aggregate,
                window_frame,
                order_by: _,
            } => {
                // TODO: deduplicate code between the various window function cases.

                let tuple = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Get the overall return type
                let return_type = self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone();
                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();

                // Extract the original row
                let original_row = tuple
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the input value
                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

                let (result, column_name) = Self::on_unique_window_agg(
                    window_frame,
                    arg_expr,
                    input_type,
                    window_agg_return_type,
                    wrapped_aggregate,
                );

                // Wrap the result back up as a 1-element list of (result, original row).
                MirScalarExpr::call_variadic(
                    ListCreate {
                        elem_type: SqlScalarType::from_repr(&return_type),
                    },
                    vec![MirScalarExpr::call_variadic(
                        RecordCreate {
                            field_names: vec![column_name, ColumnName::from("?record?")],
                        },
                        vec![result, original_row],
                    )],
                )
            }

            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
            AggregateFunc::FusedWindowAggregate {
                wrapped_aggregates,
                order_by: _,
                window_frame,
            } => {
                // Throw away OrderByExprs
                let tuple = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the original row
                let original_row = tuple
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the args of the fused call
                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

                let return_type_with_orig_row = self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone();

                let all_func_return_types =
                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
                let mut func_result_exprs = Vec::new();
                let mut col_names = Vec::new();
                // Compute each wrapped aggregate on its own arg; the i-th arg and the i-th
                // return type correspond to the i-th wrapped aggregate.
                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
                    let arg = all_args
                        .clone()
                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
                    let return_type =
                        all_func_return_types.unwrap_record_element_type()[idx].clone();
                    let (result, column_name) = Self::on_unique_window_agg(
                        window_frame,
                        arg,
                        input_type,
                        return_type,
                        wrapped_aggr,
                    );
                    func_result_exprs.push(result);
                    col_names.push(column_name);
                }

                // Wrap all results back up as a 1-element list of ((results...), original row).
                MirScalarExpr::call_variadic(
                    ListCreate {
                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
                    },
                    vec![MirScalarExpr::call_variadic(
                        RecordCreate {
                            field_names: vec![
                                ColumnName::from("?fused_window_aggr?"),
                                ColumnName::from("?record?"),
                            ],
                        },
                        vec![
                            MirScalarExpr::call_variadic(
                                RecordCreate {
                                    field_names: col_names,
                                },
                                func_result_exprs,
                            ),
                            original_row,
                        ],
                    )],
                )
            }

            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
            AggregateFunc::FusedValueWindowFunc {
                funcs,
                order_by: outer_order_by,
            } => {
                // Throw away OrderByExprs
                let tuple = self
                    .expr
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the original row
                let original_row = tuple
                    .clone()
                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

                // Extract the encoded args of the fused call
                let all_encoded_args =
                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));

                let return_type_with_orig_row = self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone();

                let all_func_return_types =
                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
                let mut func_result_exprs = Vec::new();
                let mut col_names = Vec::new();
                // Dispatch per fused function; each must share the fused call's ORDER BY.
                for (idx, func) in funcs.iter().enumerate() {
                    let args_for_func = all_encoded_args
                        .clone()
                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
                    let return_type_for_func =
                        all_func_return_types.unwrap_record_element_type()[idx].clone();
                    let (result, column_name) = match func {
                        AggregateFunc::LagLead {
                            lag_lead,
                            order_by,
                            ignore_nulls: _,
                        } => {
                            assert_eq!(order_by, outer_order_by);
                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
                        }
                        AggregateFunc::FirstValue {
                            window_frame,
                            order_by,
                        } => {
                            assert_eq!(order_by, outer_order_by);
                            Self::on_unique_first_value_last_value(
                                window_frame,
                                args_for_func,
                                return_type_for_func,
                            )
                        }
                        AggregateFunc::LastValue {
                            window_frame,
                            order_by,
                        } => {
                            assert_eq!(order_by, outer_order_by);
                            Self::on_unique_first_value_last_value(
                                window_frame,
                                args_for_func,
                                return_type_for_func,
                            )
                        }
                        _ => panic!("unknown function in FusedValueWindowFunc"),
                    };
                    func_result_exprs.push(result);
                    col_names.push(column_name);
                }

                // Wrap all results back up as a 1-element list of ((results...), original row).
                MirScalarExpr::call_variadic(
                    ListCreate {
                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
                    },
                    vec![MirScalarExpr::call_variadic(
                        RecordCreate {
                            field_names: vec![
                                ColumnName::from("?fused_value_window_func?"),
                                ColumnName::from("?record?"),
                            ],
                        },
                        vec![
                            MirScalarExpr::call_variadic(
                                RecordCreate {
                                    field_names: col_names,
                                },
                                func_result_exprs,
                            ),
                            original_row,
                        ],
                    )],
                )
            }

            // All other variants should return the argument to the aggregation.
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Any
            | AggregateFunc::All
            | AggregateFunc::Dummy => self.expr.clone(),
        }
    }
3035
3036 /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3037 fn on_unique_ranking_window_funcs(
3038 &self,
3039 input_type: &[ReprColumnType],
3040 col_name: &str,
3041 ) -> MirScalarExpr {
3042 let sql_input_type: Vec<SqlColumnType> =
3043 input_type.iter().map(SqlColumnType::from_repr).collect();
3044 let list = self
3045 .expr
3046 .clone()
3047 // extract the list within the record
3048 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3049
3050 // extract the expression within the list
3051 let record = MirScalarExpr::call_variadic(
3052 ListIndex,
3053 vec![
3054 list,
3055 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3056 ],
3057 );
3058
3059 MirScalarExpr::call_variadic(
3060 ListCreate {
3061 elem_type: self
3062 .sql_typ(&sql_input_type)
3063 .scalar_type
3064 .unwrap_list_element_type()
3065 .clone(),
3066 },
3067 vec![MirScalarExpr::call_variadic(
3068 RecordCreate {
3069 field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3070 },
3071 vec![
3072 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3073 record,
3074 ],
3075 )],
3076 )
3077 }
3078
3079 /// `on_unique` for `lag` and `lead`
3080 fn on_unique_lag_lead(
3081 lag_lead: &LagLeadType,
3082 encoded_args: MirScalarExpr,
3083 return_type: ReprScalarType,
3084 ) -> (MirScalarExpr, ColumnName) {
3085 let expr = encoded_args
3086 .clone()
3087 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3088 let offset = encoded_args
3089 .clone()
3090 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3091 let default_value =
3092 encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3093
3094 // In this case, the window always has only one element, so if the offset is not null and
3095 // not zero, the default value should be returned instead.
3096 let value = offset
3097 .clone()
3098 .call_binary(
3099 MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
3100 crate::func::Eq,
3101 )
3102 .if_then_else(expr, default_value);
3103 let result_expr = offset
3104 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3105 .if_then_else(MirScalarExpr::literal_null(return_type), value);
3106
3107 let column_name = ColumnName::from(match lag_lead {
3108 LagLeadType::Lag => "?lag?",
3109 LagLeadType::Lead => "?lead?",
3110 });
3111
3112 (result_expr, column_name)
3113 }
3114
3115 /// `on_unique` for `first_value` and `last_value`
3116 fn on_unique_first_value_last_value(
3117 window_frame: &WindowFrame,
3118 arg: MirScalarExpr,
3119 return_type: ReprScalarType,
3120 ) -> (MirScalarExpr, ColumnName) {
3121 // If the window frame includes the current (single) row, return its value, null otherwise
3122 let result_expr = if window_frame.includes_current_row() {
3123 arg
3124 } else {
3125 MirScalarExpr::literal_null(return_type)
3126 };
3127 (result_expr, ColumnName::from("?first_value?"))
3128 }
3129
3130 /// `on_unique` for window aggregations
3131 fn on_unique_window_agg(
3132 window_frame: &WindowFrame,
3133 arg_expr: MirScalarExpr,
3134 input_type: &[ReprColumnType],
3135 return_type: ReprScalarType,
3136 wrapped_aggr: &AggregateFunc,
3137 ) -> (MirScalarExpr, ColumnName) {
3138 // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3139 // that row. Otherwise, return the default value for the aggregate.
3140 let result_expr = if window_frame.includes_current_row() {
3141 AggregateExpr {
3142 func: wrapped_aggr.clone(),
3143 expr: arg_expr,
3144 distinct: false, // We have just one input element; DISTINCT doesn't matter.
3145 }
3146 .on_unique(input_type)
3147 } else {
3148 MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3149 };
3150 (result_expr, ColumnName::from("?window_agg?"))
3151 }
3152
3153 /// Returns whether the expression is COUNT(*) or not. Note that
3154 /// when we define the count builtin in sql::func, we convert
3155 /// COUNT(*) to COUNT(true), making it indistinguishable from
3156 /// literal COUNT(true), but we prefer to consider this as the
3157 /// former.
3158 ///
3159 /// (HIR has the same `is_count_asterisk`.)
3160 pub fn is_count_asterisk(&self) -> bool {
3161 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3162 }
3163}
3164
/// Describe a join implementation in dataflow.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        // (starting input's index, optional keys to arrange it by, characteristics)
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        // Remaining inputs in join order: (input index, arrangement key, characteristics).
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        // The constant rows are not reflected, hence the `ignore`.
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3223
3224impl Default for JoinImplementation {
3225 fn default() -> Self {
3226 JoinImplementation::Unimplemented
3227 }
3228}
3229
3230impl JoinImplementation {
3231 /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3232 pub fn is_implemented(&self) -> bool {
3233 match self {
3234 Self::Unimplemented => false,
3235 _ => true,
3236 }
3237 }
3238
3239 /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3240 pub fn name(&self) -> Option<&'static str> {
3241 match self {
3242 Self::Differential(..) => Some("differential"),
3243 Self::DeltaQuery(..) => Some("delta"),
3244 Self::IndexedFilter(..) => Some("indexed_filter"),
3245 Self::Unimplemented => None,
3246 }
3247 }
3248}
3249
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag.
// NOTE(review): the derived `Ord` orders every `V1` value before every `V2` value;
// presumably the two versions are never compared against each other — confirm.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3279
3280impl JoinInputCharacteristics {
3281 /// Creates a new instance with the given characteristics.
3282 pub fn new(
3283 unique_key: bool,
3284 key_length: usize,
3285 arranged: bool,
3286 cardinality: Option<usize>,
3287 filters: FilterCharacteristics,
3288 input: usize,
3289 enable_join_prioritize_arranged: bool,
3290 ) -> Self {
3291 if enable_join_prioritize_arranged {
3292 Self::V2(JoinInputCharacteristicsV2::new(
3293 unique_key,
3294 key_length,
3295 arranged,
3296 cardinality,
3297 filters,
3298 input,
3299 ))
3300 } else {
3301 Self::V1(JoinInputCharacteristicsV1::new(
3302 unique_key,
3303 key_length,
3304 arranged,
3305 cardinality,
3306 filters,
3307 input,
3308 ))
3309 }
3310 }
3311
3312 /// Turns the instance into a String to be printed in EXPLAIN.
3313 pub fn explain(&self) -> String {
3314 match self {
3315 Self::V1(jic) => jic.explain(),
3316 Self::V2(jic) => jic.explain(),
3317 }
3318 }
3319
3320 /// Whether the join input described by `self` is arranged.
3321 pub fn arranged(&self) -> bool {
3322 match self {
3323 Self::V1(jic) => jic.arranged,
3324 Self::V2(jic) => jic.arranged,
3325 }
3326 }
3327
3328 /// Returns the `FilterCharacteristics` for the join input described by `self`.
3329 pub fn filters(&mut self) -> &mut FilterCharacteristics {
3330 match self {
3331 Self::V1(jic) => &mut jic.filters,
3332 Self::V2(jic) => &mut jic.filters,
3333 }
3334 }
3335}
3336
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// NOTE: the derived `Ord` compares fields lexicographically in declaration order, so
/// earlier fields are more important when ranking join order candidates.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better; the `Reverse` wrapper makes a smaller
    /// estimate compare as greater under the derived `Ord`).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3369
3370impl JoinInputCharacteristicsV2 {
3371 /// Creates a new instance with the given characteristics.
3372 pub fn new(
3373 unique_key: bool,
3374 key_length: usize,
3375 arranged: bool,
3376 cardinality: Option<usize>,
3377 filters: FilterCharacteristics,
3378 input: usize,
3379 ) -> Self {
3380 Self {
3381 unique_key,
3382 not_cross: key_length > 0,
3383 arranged,
3384 key_length,
3385 cardinality: cardinality.map(std::cmp::Reverse),
3386 filters,
3387 input: std::cmp::Reverse(input),
3388 }
3389 }
3390
3391 /// Turns the instance into a String to be printed in EXPLAIN.
3392 pub fn explain(&self) -> String {
3393 let mut e = "".to_owned();
3394 if self.unique_key {
3395 e.push_str("U");
3396 }
3397 // Don't need to print `not_cross`, because that is visible in the printed key.
3398 // if !self.not_cross {
3399 // e.push_str("C");
3400 // }
3401 for _ in 0..self.key_length {
3402 e.push_str("K");
3403 }
3404 if self.arranged {
3405 e.push_str("A");
3406 }
3407 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3408 e.push_str(&format!("|{cardinality}|"));
3409 }
3410 e.push_str(&self.filters.explain());
3411 e
3412 }
3413}
3414
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// NOTE: the derived `Ord` compares fields lexicographically in declaration order, so
/// earlier fields are more important when ranking join order candidates.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better; the `Reverse` wrapper makes a smaller
    /// estimate compare as greater under the derived `Ord`).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3442
3443impl JoinInputCharacteristicsV1 {
3444 /// Creates a new instance with the given characteristics.
3445 pub fn new(
3446 unique_key: bool,
3447 key_length: usize,
3448 arranged: bool,
3449 cardinality: Option<usize>,
3450 filters: FilterCharacteristics,
3451 input: usize,
3452 ) -> Self {
3453 Self {
3454 unique_key,
3455 key_length,
3456 arranged,
3457 cardinality: cardinality.map(std::cmp::Reverse),
3458 filters,
3459 input: std::cmp::Reverse(input),
3460 }
3461 }
3462
3463 /// Turns the instance into a String to be printed in EXPLAIN.
3464 pub fn explain(&self) -> String {
3465 let mut e = "".to_owned();
3466 if self.unique_key {
3467 e.push_str("U");
3468 }
3469 for _ in 0..self.key_length {
3470 e.push_str("K");
3471 }
3472 if self.arranged {
3473 e.push_str("A");
3474 }
3475 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3476 e.push_str(&format!("|{cardinality}|"));
3477 }
3478 e.push_str(&self.filters.explain());
3479 e
3480 }
3481}
3482
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them. The defaults (`L = NonNeg<i64>`, `O = usize`) are the
/// fully-bound forms.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset).
    pub limit: Option<L>,
    /// Omit as many rows.
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3505
3506impl<L> RowSetFinishing<L> {
3507 /// Returns a trivial finishing, i.e., that does nothing to the result set.
3508 pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3509 RowSetFinishing {
3510 order_by: Vec::new(),
3511 limit: None,
3512 offset: 0,
3513 project: (0..arity).collect(),
3514 }
3515 }
3516 /// True if the finishing does nothing to any result set.
3517 pub fn is_trivial(&self, arity: usize) -> bool {
3518 self.limit.is_none()
3519 && self.order_by.is_empty()
3520 && self.offset == 0
3521 && self.project.iter().copied().eq(0..arity)
3522 }
3523 /// True if the finishing does not require an ORDER BY.
3524 ///
3525 /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3526 /// explicit ordering we will skip an arbitrary bag of elements and return
3527 /// the first arbitrary elements in the remaining bag. The result semantics
3528 /// are still correct but maybe surprising for some users.
3529 pub fn is_streamable(&self, arity: usize) -> bool {
3530 self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3531 }
3532}
3533
3534impl RowSetFinishing<NonNeg<i64>, usize> {
3535 /// The number of rows needed from before the finishing to evaluate the finishing:
3536 /// offset + limit.
3537 ///
3538 /// If it returns None, then we need all the rows.
3539 pub fn num_rows_needed(&self) -> Option<usize> {
3540 self.limit
3541 .as_ref()
3542 .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3543 }
3544}
3545
3546impl RowSetFinishing {
3547 /// Applies finishing actions to a [`RowCollection`], and reports the total
3548 /// time it took to run.
3549 ///
3550 /// Returns a [`RowCollectionIter`] that contains all of the response data, as
3551 /// well as the size of the response in bytes.
3552 pub fn finish(
3553 &self,
3554 rows: RowCollection,
3555 max_result_size: u64,
3556 max_returned_query_size: Option<u64>,
3557 duration_histogram: &Histogram,
3558 ) -> Result<(RowCollectionIter, usize), String> {
3559 let now = Instant::now();
3560 let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3561 let duration = now.elapsed();
3562 duration_histogram.observe(duration.as_secs_f64());
3563
3564 result
3565 }
3566
3567 /// Implementation for [`RowSetFinishing::finish`].
3568 fn finish_inner(
3569 &self,
3570 rows: RowCollection,
3571 max_result_size: u64,
3572 max_returned_query_size: Option<u64>,
3573 ) -> Result<(RowCollectionIter, usize), String> {
3574 // How much additional memory is required to make a sorted view.
3575 let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3576 let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3577
3578 // Bail if creating the sorted view would require us to use too much memory.
3579 if required_memory > usize::cast_from(max_result_size) {
3580 let max_bytes = ByteSize::b(max_result_size);
3581 return Err(format!("result exceeds max size of {max_bytes}",));
3582 }
3583
3584 let sorted_view = rows;
3585 let mut iter = sorted_view
3586 .into_row_iter()
3587 .apply_offset(self.offset)
3588 .with_projection(self.project.clone());
3589
3590 if let Some(limit) = self.limit {
3591 let limit = u64::from(limit);
3592 let limit = usize::cast_from(limit);
3593 iter = iter.with_limit(limit);
3594 };
3595
3596 // TODO(parkmycar): Re-think how we can calculate the total response size without
3597 // having to iterate through the entire collection of Rows, while still
3598 // respecting the LIMIT, OFFSET, and projections.
3599 //
3600 // Note: It feels a bit bad always calculating the response size, but we almost
3601 // always need it to either check the `max_returned_query_size`, or for reporting
3602 // in the query history.
3603 let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3604
3605 // Bail if we would end up returning more data to the client than they can support.
3606 if let Some(max) = max_returned_query_size {
3607 if response_size > usize::cast_from(max) {
3608 let max_bytes = ByteSize::b(max);
3609 return Err(format!("result exceeds max size of {max_bytes}"));
3610 }
3611 }
3612
3613 Ok((iter, response_size))
3614 }
3615}
3616
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
///
/// The `remaining_*` fields are running budgets, mutated by each call to
/// `finish_incremental`.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset).
    pub remaining_limit: Option<usize>,
    /// Omit as many rows.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3633
impl RowSetFinishingIncremental {
    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
    /// `true`.
    ///
    /// # Panics
    ///
    /// Panics if the result is not streamable, that is it has an ORDER BY.
    // NOTE(review): no assertion is visible here despite the documented panic;
    // presumably callers check `is_streamable` before constructing — confirm.
    pub fn new(
        offset: usize,
        limit: Option<NonNeg<i64>>,
        project: Vec<usize>,
        max_returned_query_size: Option<u64>,
    ) -> Self {
        // Convert the non-negative i64 limit into a usize row count.
        let limit = limit.map(|l| {
            let l = u64::from(l);
            let l = usize::cast_from(l);
            l
        });

        RowSetFinishingIncremental {
            remaining_limit: limit,
            remaining_offset: offset,
            max_returned_query_size,
            remaining_max_returned_query_size: max_returned_query_size,
            project,
        }
    }

    /// Applies finishing actions to the given [`RowCollection`], and reports
    /// the total time it took to run.
    ///
    /// Returns a [`RowCollectionIter`] that contains all of the response
    /// data.
    pub fn finish_incremental(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
        duration_histogram: &Histogram,
    ) -> Result<RowCollectionIter, String> {
        // Time the whole batch, success or failure, into the histogram.
        let now = Instant::now();
        let result = self.finish_incremental_inner(rows, max_result_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishingIncremental::finish_incremental`]:
    /// applies the remaining offset/limit budgets to one batch and updates them.
    fn finish_incremental_inner(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
    ) -> Result<RowCollectionIter, String> {
        // How much additional memory is required to make a sorted view.
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("total result exceeds max size of {max_bytes}",));
        }

        // Row count of this batch before offset/limit, used below to decrement
        // the remaining offset.
        let batch_num_rows = rows.count();

        let sorted_view = rows;
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.remaining_offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.remaining_limit {
            iter = iter.with_limit(limit);
        };

        // Consume this batch from the running offset/limit budgets.
        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
            // NOTE(review): assumes `count` does not exhaust `iter` (it is cloned and
            // returned below) and that `iter.count() <= *remaining_limit` because of the
            // `with_limit` above, so this subtraction cannot underflow — confirm.
            *remaining_limit -= iter.count();
        }

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        // NOTE(review): `remaining_max_returned_query_size` is checked but never reduced
        // by `response_size` here — confirm whether callers update the budget elsewhere.
        if let Some(max) = self.remaining_max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
                return Err(format!("total result exceeds max size of {max_bytes}"));
            }
        }

        Ok(iter)
    }
}
3734
3735/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3736/// ordering, call `tiebreaker`.
3737pub fn compare_columns<F>(
3738 order: &[ColumnOrder],
3739 left: &[Datum],
3740 right: &[Datum],
3741 tiebreaker: F,
3742) -> Ordering
3743where
3744 F: Fn() -> Ordering,
3745{
3746 for order in order {
3747 let cmp = match (&left[order.column], &right[order.column]) {
3748 (Datum::Null, Datum::Null) => Ordering::Equal,
3749 (Datum::Null, _) => {
3750 if order.nulls_last {
3751 Ordering::Greater
3752 } else {
3753 Ordering::Less
3754 }
3755 }
3756 (_, Datum::Null) => {
3757 if order.nulls_last {
3758 Ordering::Less
3759 } else {
3760 Ordering::Greater
3761 }
3762 }
3763 (lval, rval) => {
3764 if order.desc {
3765 rval.cmp(lval)
3766 } else {
3767 lval.cmp(rval)
3768 }
3769 }
3770 };
3771 if cmp != Ordering::Equal {
3772 return cmp;
3773 }
3774 }
3775 tiebreaker()
3776}
3777
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3803
3804impl Display for WindowFrame {
3805 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3806 write!(
3807 f,
3808 "{} between {} and {}",
3809 self.units, self.start_bound, self.end_bound
3810 )
3811 }
3812}
3813
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined
    // NOTE: this is an inherent method, not an implementation of the `Default` trait.
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Whether the frame includes the current row itself, decided purely from the
    /// start/end bounds (i.e., independent of `units`).
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            // Start is at the very beginning: the frame includes the current row
            // unless the end bound stops strictly before it.
            UnboundedPreceding => match self.end_bound {
                UnboundedPreceding => false,
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // `0 PRECEDING` starts exactly at the current row.
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Start is strictly before the current row.
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Starting at the current row always includes it (end bounds that would
            // precede the start are rejected earlier, per the unreachable!()s below).
            CurrentRow => true,
            // `0 FOLLOWING` also starts exactly at the current row.
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Start is strictly after the current row: it can never be included.
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            UnboundedFollowing => false,
        }
    }
}
3872
/// Describe how frame bounds are interpreted
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3896
3897impl Display for WindowFrameUnits {
3898 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3899 match self {
3900 WindowFrameUnits::Rows => write!(f, "rows"),
3901 WindowFrameUnits::Range => write!(f, "range"),
3902 WindowFrameUnits::Groups => write!(f, "groups"),
3903 }
3904 }
3905}
3906
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there. (The variants are declared in frame order, so the
/// derived `Ord` follows that order.)
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
3935
3936impl Display for WindowFrameBound {
3937 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3938 match self {
3939 WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
3940 WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
3941 WindowFrameBound::CurrentRow => write!(f, "current row"),
3942 WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
3943 WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
3944 }
3945 }
3946}
3947
/// Maximum iterations for a LetRec.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    /// (`NonZeroU64`: a limit of 0 iterations is not representable.)
    pub max_iters: NonZeroU64,
    /// Whether to throw an error when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
3968
3969impl LetRecLimit {
3970 /// Compute the smallest limit from a Vec of `LetRecLimit`s.
3971 pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
3972 limits
3973 .iter()
3974 .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
3975 .min()
3976 }
3977
3978 /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
3979 /// WMR without ERROR AT or RETURN AT.
3980 pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
3981}
3982
3983impl Display for LetRecLimit {
3984 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3985 write!(f, "[recursion_limit={}", self.max_iters)?;
3986 if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
3987 write!(f, ", return_at_limit")?;
3988 }
3989 write!(f, "]")
3990 }
3991}
3992
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
///
/// Global Gets start out as `UnknownOrLocal`, and are switched to one of the other
/// variants by `prune_and_annotate_dataflow_index_imports`.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4019
#[cfg(test)]
mod tests {
    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// Checks the text rendering of a `RowSetFinishing` with all clauses present.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // The expected output is a constant, so state it directly. (The previous
        // version assembled it with `write!` and silently discarded the
        // `#[must_use]` `fmt::Result` values, triggering unused-result warnings.)
        let exp = "Finish order_by=[#4 desc nulls_last] limit=7 output=[#1, #3..=#5]\n";

        assert_eq!(act, exp);
    }
}
4060
/// An iterator over AST structures, which calls out nodes in difference.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and they are incidentally a
/// stack-free way to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of corresponding subexpressions whose own node data does not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        /// Pops pairs off the worklist; for each pair it compares the nodes' own
        /// data (not their children): on a mismatch the pair is yielded, and on a
        /// match the corresponding children are pushed for later comparison.
        fn next(&mut self) -> Option<Self::Item> {
            // An explicit worklist instead of recursion keeps this stack-free
            // even for deep expression trees.
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            // Compare both children later.
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            // Lengths were checked above, so `zip_eq` cannot panic.
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            // Lengths were checked above, so `zip_eq` cannot panic.
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        // No node data to compare; just descend.
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        // No node data to compare; just descend.
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            // Lengths were checked above, so `zip_eq` cannot panic.
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    _ => {
                        // The two nodes are different variants, so they differ.
                        return Some((expr1, expr2));
                    }
                }
            }
            // Worklist exhausted: the remaining (sub)trees all matched.
            None
        }
    }
}