// mz_expr/relation.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::explain::text::text_string_at;
34use mz_repr::explain::{
35 DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
36};
37use mz_repr::{
38 ColumnName, Datum, Diff, GlobalId, IntoRowIterator, ReprColumnType, Row, RowIterator,
39 SqlColumnType, SqlRelationType, SqlScalarType,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::Id::Local;
44use crate::explain::{HumanizedExpr, HumanizerMode};
45use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
46use crate::row::{RowCollection, SortedRowCollectionIter};
47use crate::visit::{Visit, VisitChildren};
48use crate::{
49 EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, VariadicFunc,
50 func as scalar_func,
51};
52
53pub mod canonicalize;
54pub mod func;
55pub mod join_input_mapper;
56
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
68
69/// A trait for types that describe how to build a collection.
70pub trait CollectionPlan {
71 /// Collects the set of global identifiers from dataflows referenced in Get.
72 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
73
74 /// Returns the set of global identifiers from dataflows referenced in Get.
75 ///
76 /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
77 fn depends_on(&self) -> BTreeSet<GlobalId> {
78 let mut out = BTreeSet::new();
79 self.depends_on_into(&mut out);
80 out
81 }
82}
83
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: SqlRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: SqlRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`, evaluated with each `id` bound to its `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each element of `keys` describes a different index (arrangement)
    /// that should be produced over the input.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
312
313impl PartialEq for MirRelationExpr {
314 fn eq(&self, other: &Self) -> bool {
315 // Capture the result and test it wrt `Ord` implementation in test environments.
316 let result = structured_diff::MreDiff::new(self, other).next().is_none();
317 mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
318 result
319 }
320}
321impl Eq for MirRelationExpr {}
322
323impl MirRelationExpr {
    /// Reports the schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> SqlRelationType {
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: optionally override which children get traversed.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    _ => None,
                }
            },
            // Post-visit: compute this node's type from its children's types.
            &mut |e: &MirRelationExpr| {
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(SqlRelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack
                            .extend(std::iter::repeat(SqlRelationType::empty()).take(values.len()));
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                // Replace this node's input types on the stack with its own type.
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // After the traversal only the root's relation type remains.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
382
383 /// Reports the schema of the relation given the schema of the input relations.
384 ///
385 /// `input_types` is required to contain the schemas for the input relations of
386 /// the current relation in the same order as they are visited by `try_visit_children`
387 /// method, even though not all may be used for computing the schema of the
388 /// current relation. For example, `Let` expects two input types, one for the
389 /// value relation and one for the body, in that order, but only the one for the
390 /// body is used to determine the type of the `Let` relation.
391 ///
392 /// It is meant to be used during post-order traversals to compute relation
393 /// schemas incrementally.
394 pub fn typ_with_input_types(&self, input_types: &[SqlRelationType]) -> SqlRelationType {
395 let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
396 let unique_keys = self.keys_with_input_keys(
397 input_types.iter().map(|i| i.arity()),
398 input_types.iter().map(|i| &i.keys),
399 );
400 SqlRelationType::new(column_types).with_keys(unique_keys)
401 }
402
403 /// Reports the column types of the relation given the column types of the
404 /// input relations.
405 ///
406 /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
407 /// variant is returned.
408 pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<SqlColumnType>
409 where
410 I: Iterator<Item = &'a Vec<SqlColumnType>>,
411 {
412 match self.try_col_with_input_cols(input_types) {
413 Ok(col_types) => col_types,
414 Err(err) => panic!("{err}"),
415 }
416 }
417
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<SqlColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<SqlColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: mark a column non-nullable unless some row
                // actually contains a null in it.
                let mut col_types = typ.column_types.clone();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null datum in a column declared non-nullable is a bug upstream.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                let mut result = input_types.next().unwrap().clone();
                // Each scalar may reference the columns produced by earlier scalars,
                // so extend `result` incrementally.
                for scalar in scalars.iter() {
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                let mut result = input_types.next().unwrap().clone();
                result.extend(func.output_type().column_types);
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output columns are the group key expressions followed by the aggregates.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                // These operators pass their input's column types through unchanged.
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Unify column types pairwise across all inputs; a mismatch is an error.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
545
546 /// Reports the column types of the relation given the column types of the
547 /// input relations.
548 ///
549 /// This method delegates to `try_repr_col_with_input_repr_cols`, panicking if an `Err`
550 /// variant is returned.
551 pub fn repr_col_with_input_repr_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
552 where
553 I: Iterator<Item = &'a Vec<ReprColumnType>>,
554 {
555 match self.try_repr_col_with_input_repr_cols(input_types) {
556 Ok(col_types) => col_types,
557 Err(err) => panic!("{err}"),
558 }
559 }
560
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    ///
    /// NOTE(review): this mirrors `try_col_with_input_cols`, but over `ReprColumnType`
    /// (and `repr_typ`) — keep the two in sync when changing either.
    pub fn try_repr_col_with_input_repr_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ReprColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ReprColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: mark a column non-nullable unless some row
                // actually contains a null in it.
                let mut col_types = typ
                    .column_types
                    .iter()
                    .map(ReprColumnType::from)
                    .collect_vec();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null datum in a column declared non-nullable is a bug upstream.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ
                .column_types
                .iter()
                .map(ReprColumnType::from)
                .collect_vec(),
            Project { outputs, .. } => {
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                let mut result = input_types.next().unwrap().clone();
                // Each scalar may reference the columns produced by earlier scalars,
                // so extend `result` incrementally.
                for scalar in scalars.iter() {
                    result.push(scalar.repr_typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                let mut result = input_types.next().unwrap().clone();
                result.extend(
                    func.output_type()
                        .column_types
                        .iter()
                        .map(ReprColumnType::from),
                );
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output columns are the group key expressions followed by the aggregates.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.repr_typ(input))
                    .chain(aggregates.iter().map(|agg| agg.repr_typ(input)))
                    .collect()
            }
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                // These operators pass their input's column types through unchanged.
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Unify column types pairwise across all inputs; a mismatch is an error.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
701
702 /// Reports the unique keys of the relation given the arities and the unique
703 /// keys of the input relations.
704 ///
705 /// `input_arities` and `input_keys` are required to contain the
706 /// corresponding info for the input relations of
707 /// the current relation in the same order as they are visited by `try_visit_children`
708 /// method, even though not all may be used for computing the schema of the
709 /// current relation. For example, `Let` expects two input types, one for the
710 /// value relation and one for the body, in that order, but only the one for the
711 /// body is used to determine the type of the `Let` relation.
712 ///
713 /// It is meant to be used during post-order traversals to compute unique keys
714 /// incrementally.
715 pub fn keys_with_input_keys<'a, I, J>(
716 &self,
717 mut input_arities: I,
718 mut input_keys: J,
719 ) -> Vec<Vec<usize>>
720 where
721 I: Iterator<Item = usize>,
722 J: Iterator<Item = &'a Vec<Vec<usize>>>,
723 {
724 use MirRelationExpr::*;
725
726 let mut keys = match self {
727 Constant {
728 rows: Ok(rows),
729 typ,
730 } => {
731 let n_cols = typ.arity();
732 // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
733 let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
734 for (row, diff) in rows {
735 for (i, datum) in row.iter().enumerate() {
736 if datum != Datum::Dummy {
737 if let Some(unique_vals) = &mut unique_values_per_col[i] {
738 let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
739 if is_dupe {
740 unique_values_per_col[i] = None;
741 }
742 }
743 }
744 }
745 }
746 if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
747 vec![vec![]]
748 } else {
749 // XXX - Multi-column keys are not detected.
750 typ.keys
751 .iter()
752 .cloned()
753 .chain(
754 unique_values_per_col
755 .into_iter()
756 .enumerate()
757 .filter(|(_idx, unique_vals)| unique_vals.is_some())
758 .map(|(idx, _)| vec![idx]),
759 )
760 .collect()
761 }
762 }
763 Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
764 Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
765 Let { .. } => {
766 // skip over the unique keys for value
767 input_keys.nth(1).unwrap().clone()
768 }
769 LetRec { values, .. } => {
770 // skip over the unique keys for value
771 input_keys.nth(values.len()).unwrap().clone()
772 }
773 Project { outputs, .. } => {
774 let input = input_keys.next().unwrap();
775 input
776 .iter()
777 .filter_map(|key_set| {
778 if key_set.iter().all(|k| outputs.contains(k)) {
779 Some(
780 key_set
781 .iter()
782 .map(|c| outputs.iter().position(|o| o == c).unwrap())
783 .collect(),
784 )
785 } else {
786 None
787 }
788 })
789 .collect()
790 }
791 Map { scalars, .. } => {
792 let mut remappings = Vec::new();
793 let arity = input_arities.next().unwrap();
794 for (column, scalar) in scalars.iter().enumerate() {
795 // assess whether the scalar preserves uniqueness,
796 // and could participate in a key!
797
798 fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
799 match expr {
800 MirScalarExpr::CallUnary { func, expr } => {
801 if func.preserves_uniqueness() {
802 uniqueness(expr)
803 } else {
804 None
805 }
806 }
807 MirScalarExpr::Column(c, _name) => Some(*c),
808 _ => None,
809 }
810 }
811
812 if let Some(c) = uniqueness(scalar) {
813 remappings.push((c, column + arity));
814 }
815 }
816
817 let mut result = input_keys.next().unwrap().clone();
818 let mut new_keys = Vec::new();
819 // Any column in `remappings` could be replaced in a key
820 // by the corresponding c. This could lead to combinatorial
821 // explosion using our current representation, so we wont
822 // do that. Instead, we'll handle the case of one remapping.
823 if remappings.len() == 1 {
824 let (old, new) = remappings.pop().unwrap();
825 for key in &result {
826 if key.contains(&old) {
827 let mut new_key: Vec<usize> =
828 key.iter().cloned().filter(|k| k != &old).collect();
829 new_key.push(new);
830 new_key.sort_unstable();
831 new_keys.push(new_key);
832 }
833 }
834 result.append(&mut new_keys);
835 }
836 result
837 }
838 FlatMap { .. } => {
839 // FlatMap can add duplicate rows, so input keys are no longer
840 // valid
841 vec![]
842 }
843 Negate { .. } => {
844 // Although negate may have distinct records for each key,
845 // the multiplicity is -1 rather than 1. This breaks many
846 // of the optimization uses of "keys".
847 vec![]
848 }
849 Filter { predicates, .. } => {
850 // A filter inherits the keys of its input unless the filters
851 // have reduced the input to a single row, in which case the
852 // keys of the input are `()`.
853 let mut input = input_keys.next().unwrap().clone();
854
855 if !input.is_empty() {
856 // Track columns equated to literals, which we can prune.
857 let mut cols_equal_to_literal = BTreeSet::new();
858
859 // Perform union find on `col1 = col2` to establish
860 // connected components of equated columns. Absent any
861 // equalities, this will be `0 .. #c` (where #c is the
862 // greatest column referenced by a predicate), but each
863 // equality will orient the root of the greater to the root
864 // of the lesser.
865 let mut union_find = Vec::new();
866
867 for expr in predicates.iter() {
868 if let MirScalarExpr::CallBinary {
869 func: crate::BinaryFunc::Eq(_),
870 expr1,
871 expr2,
872 } = expr
873 {
874 if let MirScalarExpr::Column(c, _name) = &**expr1 {
875 if expr2.is_literal_ok() {
876 cols_equal_to_literal.insert(c);
877 }
878 }
879 if let MirScalarExpr::Column(c, _name) = &**expr2 {
880 if expr1.is_literal_ok() {
881 cols_equal_to_literal.insert(c);
882 }
883 }
884 // Perform union-find to equate columns.
885 if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
886 if c1 != c2 {
887 // Ensure union_find has entries up to
888 // max(c1, c2) by filling up missing
889 // positions with identity mappings.
890 while union_find.len() <= std::cmp::max(c1, c2) {
891 union_find.push(union_find.len());
892 }
893 let mut r1 = c1; // Find the representative column of [c1].
894 while r1 != union_find[r1] {
895 assert!(union_find[r1] < r1);
896 r1 = union_find[r1];
897 }
898 let mut r2 = c2; // Find the representative column of [c2].
899 while r2 != union_find[r2] {
900 assert!(union_find[r2] < r2);
901 r2 = union_find[r2];
902 }
903 // Union [c1] and [c2] by pointing the
904 // larger to the smaller representative (we
905 // update the remaining equivalence class
906 // members only once after this for-loop).
907 union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
908 }
909 }
910 }
911 }
912
913 // Complete union-find by pointing each element at its representative column.
914 for i in 0..union_find.len() {
915 // Iteration not required, as each prior already references the right column.
916 union_find[i] = union_find[union_find[i]];
917 }
918
919 // Remove columns bound to literals, and remap columns equated to earlier columns.
920 // We will re-expand remapped columns in a moment, but this avoids exponential work.
921 for key_set in &mut input {
922 key_set.retain(|k| !cols_equal_to_literal.contains(&k));
923 for col in key_set.iter_mut() {
924 if let Some(equiv) = union_find.get(*col) {
925 *col = *equiv;
926 }
927 }
928 key_set.sort();
929 key_set.dedup();
930 }
931 input.sort();
932 input.dedup();
933
934 // Expand out each key to each of its equivalent forms.
935 // Each instance of `col` can be replaced by any equivalent column.
936 // This has the potential to result in exponentially sized number of unique keys,
937 // and in the future we should probably maintain unique keys modulo equivalence.
938
939 // First, compute an inverse map from each representative
940 // column `sub` to all other equivalent columns `col`.
941 let mut subs = Vec::new();
942 for (col, sub) in union_find.iter().enumerate() {
943 if *sub != col {
944 assert!(*sub < col);
945 while subs.len() <= *sub {
946 subs.push(Vec::new());
947 }
948 subs[*sub].push(col);
949 }
950 }
951 // For each column, substitute for it in each occurrence.
952 let mut to_add = Vec::new();
953 for (col, subs) in subs.iter().enumerate() {
954 if !subs.is_empty() {
955 for key_set in input.iter() {
956 if key_set.contains(&col) {
957 let mut to_extend = key_set.clone();
958 to_extend.retain(|c| c != &col);
959 for sub in subs {
960 to_extend.push(*sub);
961 to_add.push(to_extend.clone());
962 to_extend.pop();
963 }
964 }
965 }
966 }
967 // No deduplication, as we cannot introduce duplicates.
968 input.append(&mut to_add);
969 }
970 for key_set in input.iter_mut() {
971 key_set.sort();
972 key_set.dedup();
973 }
974 }
975 input
976 }
977 Join { equivalences, .. } => {
978 // It is important the `new_from_input_arities` constructor is
979 // used. Otherwise, Materialize may potentially end up in an
980 // infinite loop.
981 let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
982
983 input_mapper.global_keys(input_keys, equivalences)
984 }
985 Reduce { group_key, .. } => {
986 // The group key should form a key, but we might already have
987 // keys that are subsets of the group key, and should retain
988 // those instead, if so.
989 let mut result = Vec::new();
990 for key_set in input_keys.next().unwrap() {
991 if key_set
992 .iter()
993 .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
994 {
995 result.push(
996 key_set
997 .iter()
998 .map(|i| {
999 group_key
1000 .iter()
1001 .position(|k| k == &MirScalarExpr::column(*i))
1002 .unwrap()
1003 })
1004 .collect::<Vec<_>>(),
1005 );
1006 }
1007 }
1008 if result.is_empty() {
1009 result.push((0..group_key.len()).collect());
1010 }
1011 result
1012 }
1013 TopK {
1014 group_key, limit, ..
1015 } => {
1016 // If `limit` is `Some(1)` then the group key will become
1017 // a unique key, as there will be only one record with that key.
1018 let mut result = input_keys.next().unwrap().clone();
1019 if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
1020 result.push(group_key.clone())
1021 }
1022 result
1023 }
1024 Union { base, inputs } => {
1025 // Generally, unions do not have any unique keys, because
1026 // each input might duplicate some. However, there is at
1027 // least one idiomatic structure that does preserve keys,
1028 // which results from SQL aggregations that must populate
1029 // absent records with default values. In that pattern,
1030 // the union of one GET with its negation, which has first
1031 // been subjected to a projection and map, we can remove
1032 // their influence on the key structure.
1033 //
1034 // If there are A, B, each with a unique `key` such that
1035 // we are looking at
1036 //
1037 // A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
1038 //
1039 // Then we can report `key` as a unique key.
1040 //
1041 // TODO: make unique key structure an optimization analysis
1042 // rather than part of the type information.
1043 // TODO: perhaps ensure that (above) A.proj(key) is a
1044 // subset of B, as otherwise there are negative records
1045 // and who knows what is true (not expected, but again
1046 // who knows what the query plan might look like).
1047
1048 let arity = input_arities.next().unwrap();
1049 let (base_projection, base_with_project_stripped) =
1050 if let MirRelationExpr::Project { input, outputs } = &**base {
1051 (outputs.clone(), &**input)
1052 } else {
1053 // A input without a project is equivalent to an input
1054 // with the project being all columns in the input in order.
1055 ((0..arity).collect::<Vec<_>>(), &**base)
1056 };
1057 let mut result = Vec::new();
1058 if let MirRelationExpr::Get {
1059 id: first_id,
1060 typ: _,
1061 ..
1062 } = base_with_project_stripped
1063 {
1064 if inputs.len() == 1 {
1065 if let MirRelationExpr::Map { input, .. } = &inputs[0] {
1066 if let MirRelationExpr::Union { base, inputs } = &**input {
1067 if inputs.len() == 1 {
1068 if let Some((input, outputs)) = base.is_negated_project() {
1069 if let MirRelationExpr::Get {
1070 id: second_id,
1071 typ: _,
1072 ..
1073 } = input
1074 {
1075 if first_id == second_id {
1076 result.extend(
1077 input_keys
1078 .next()
1079 .unwrap()
1080 .into_iter()
1081 .filter(|key| {
1082 key.iter().all(|c| {
1083 outputs.get(*c) == Some(c)
1084 && base_projection.get(*c)
1085 == Some(c)
1086 })
1087 })
1088 .cloned(),
1089 );
1090 }
1091 }
1092 }
1093 }
1094 }
1095 }
1096 }
1097 }
1098 // Important: do not inherit keys of either input, as not unique.
1099 result
1100 }
1101 };
1102 keys.sort();
1103 keys.dedup();
1104 keys
1105 }
1106
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        // Post-order accumulator: one entry per already-visited subtree whose
        // arity its parent has yet to consume.
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: optionally restrict which children get traversed.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post-visit: replay the pruning above by pushing placeholder
            // zeros, so that the stack holds exactly one entry per child
            // before `arity_with_input_arities` drains them.
            &mut |e: &MirRelationExpr| {
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Only the body was traversed; slide a placeholder for
                        // the skipped value underneath the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // Only the body was traversed; insert one placeholder
                        // per skipped bound value underneath the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No children were traversed; push a placeholder for
                        // the single input (its arity is never consulted).
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Replace the children's arities with this operator's arity.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the traversal only the root's arity remains.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1164
1165 /// Reports the arity of the relation given the schema of the input relations.
1166 ///
1167 /// `input_arities` is required to contain the arities for the input relations of
1168 /// the current relation in the same order as they are visited by `try_visit_children`
1169 /// method, even though not all may be used for computing the schema of the
1170 /// current relation. For example, `Let` expects two input types, one for the
1171 /// value relation and one for the body, in that order, but only the one for the
1172 /// body is used to determine the type of the `Let` relation.
1173 ///
1174 /// It is meant to be used during post-order traversals to compute arities
1175 /// incrementally.
1176 pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1177 where
1178 I: Iterator<Item = usize>,
1179 {
1180 use MirRelationExpr::*;
1181
1182 match self {
1183 Constant { rows: _, typ } => typ.arity(),
1184 Get { typ, .. } => typ.arity(),
1185 Let { .. } => {
1186 input_arities.next();
1187 input_arities.next().unwrap()
1188 }
1189 LetRec { values, .. } => {
1190 for _ in 0..values.len() {
1191 input_arities.next();
1192 }
1193 input_arities.next().unwrap()
1194 }
1195 Project { outputs, .. } => outputs.len(),
1196 Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1197 FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1198 Join { .. } => input_arities.sum(),
1199 Reduce {
1200 input: _,
1201 group_key,
1202 aggregates,
1203 ..
1204 } => group_key.len() + aggregates.len(),
1205 Filter { .. }
1206 | TopK { .. }
1207 | Negate { .. }
1208 | Threshold { .. }
1209 | Union { .. }
1210 | ArrangeBy { .. } => input_arities.next().unwrap(),
1211 }
1212 }
1213
1214 /// The number of child relations this relation has.
1215 pub fn num_inputs(&self) -> usize {
1216 let mut count = 0;
1217
1218 self.visit_children(|_| count += 1);
1219
1220 count
1221 }
1222
1223 /// Constructs a constant collection from specific rows and schema, where
1224 /// each row will have a multiplicity of one.
1225 pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
1226 let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1227 MirRelationExpr::constant_diff(rows, typ)
1228 }
1229
1230 /// Constructs a constant collection from specific rows and schema, where
1231 /// each row can have an arbitrary multiplicity.
1232 pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: SqlRelationType) -> Self {
1233 for (row, _diff) in &rows {
1234 for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1235 assert!(
1236 datum.is_instance_of_sql(column_typ),
1237 "Expected datum of type {:?}, got value {:?}",
1238 column_typ,
1239 datum
1240 );
1241 }
1242 }
1243 let rows = Ok(rows
1244 .into_iter()
1245 .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1246 .collect());
1247 MirRelationExpr::Constant { rows, typ }
1248 }
1249
1250 /// If self is a constant, return the value and the type, otherwise `None`.
1251 /// Looks behind `ArrangeBy`s.
1252 pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &SqlRelationType)> {
1253 match self {
1254 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1255 MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1256 _ => None,
1257 }
1258 }
1259
1260 /// If self is a constant, mutably return the value and the type, otherwise `None`.
1261 /// Looks behind `ArrangeBy`s.
1262 pub fn as_const_mut(
1263 &mut self,
1264 ) -> Option<(
1265 &mut Result<Vec<(Row, Diff)>, EvalError>,
1266 &mut SqlRelationType,
1267 )> {
1268 match self {
1269 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1270 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1271 _ => None,
1272 }
1273 }
1274
1275 /// If self is a constant error, return the error, otherwise `None`.
1276 /// Looks behind `ArrangeBy`s.
1277 pub fn as_const_err(&self) -> Option<&EvalError> {
1278 match self {
1279 MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1280 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1281 _ => None,
1282 }
1283 }
1284
1285 /// Checks if `self` is the single element collection with no columns.
1286 pub fn is_constant_singleton(&self) -> bool {
1287 if let Some((Ok(rows), typ)) = self.as_const() {
1288 rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1289 } else {
1290 false
1291 }
1292 }
1293
1294 /// Constructs the expression for getting a local collection.
1295 pub fn local_get(id: LocalId, typ: SqlRelationType) -> Self {
1296 MirRelationExpr::Get {
1297 id: Id::Local(id),
1298 typ,
1299 access_strategy: AccessStrategy::UnknownOrLocal,
1300 }
1301 }
1302
1303 /// Constructs the expression for getting a global collection
1304 pub fn global_get(id: GlobalId, typ: SqlRelationType) -> Self {
1305 MirRelationExpr::Get {
1306 id: Id::Global(id),
1307 typ,
1308 access_strategy: AccessStrategy::UnknownOrLocal,
1309 }
1310 }
1311
1312 /// Retains only the columns specified by `output`.
1313 pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1314 if let MirRelationExpr::Project {
1315 outputs: columns, ..
1316 } = &mut self
1317 {
1318 // Update `outputs` to reference base columns of `input`.
1319 for column in outputs.iter_mut() {
1320 *column = columns[*column];
1321 }
1322 *columns = outputs;
1323 self
1324 } else {
1325 MirRelationExpr::Project {
1326 input: Box::new(self),
1327 outputs,
1328 }
1329 }
1330 }
1331
1332 /// Append to each row the results of applying elements of `scalar`.
1333 pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1334 if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1335 s.extend(scalars);
1336 self
1337 } else if !scalars.is_empty() {
1338 MirRelationExpr::Map {
1339 input: Box::new(self),
1340 scalars,
1341 }
1342 } else {
1343 self
1344 }
1345 }
1346
1347 /// Append to each row a single `scalar`.
1348 pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1349 self.map(vec![scalar])
1350 }
1351
1352 /// Like `map`, but yields zero-or-more output rows per input row
1353 pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1354 MirRelationExpr::FlatMap {
1355 input: Box::new(self),
1356 func,
1357 exprs,
1358 }
1359 }
1360
1361 /// Retain only the rows satisfying each of several predicates.
1362 pub fn filter<I>(mut self, predicates: I) -> Self
1363 where
1364 I: IntoIterator<Item = MirScalarExpr>,
1365 {
1366 // Extract existing predicates
1367 let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1368 self = *input;
1369 predicates
1370 } else {
1371 Vec::new()
1372 };
1373 // Normalize collection of predicates.
1374 new_predicates.extend(predicates);
1375 new_predicates.retain(|p| !p.is_literal_true());
1376 new_predicates.sort();
1377 new_predicates.dedup();
1378 // Introduce a `Filter` only if we have predicates.
1379 if !new_predicates.is_empty() {
1380 self = MirRelationExpr::Filter {
1381 input: Box::new(self),
1382 predicates: new_predicates,
1383 };
1384 }
1385
1386 self
1387 }
1388
1389 /// Form the Cartesian outer-product of rows in both inputs.
1390 pub fn product(mut self, right: Self) -> Self {
1391 if right.is_constant_singleton() {
1392 self
1393 } else if self.is_constant_singleton() {
1394 right
1395 } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1396 inputs.push(right);
1397 self
1398 } else {
1399 MirRelationExpr::join(vec![self, right], vec![])
1400 }
1401 }
1402
1403 /// Performs a relational equijoin among the input collections.
1404 ///
1405 /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1406 /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1407 /// of pairs `(input_index, column_index)` where every value described by that set must be equal.
1408 ///
1409 /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1410 /// input collection and for each row examining its `column`th column.
1411 ///
1412 /// # Example
1413 ///
1414 /// ```rust
1415 /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
1416 /// use mz_expr::MirRelationExpr;
1417 ///
1418 /// // A common schema for each input.
1419 /// let schema = SqlRelationType::new(vec![
1420 /// SqlScalarType::Int32.nullable(false),
1421 /// SqlScalarType::Int32.nullable(false),
1422 /// ]);
1423 ///
1424 /// // the specific data are not important here.
1425 /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1426 ///
1427 /// // Three collections that could have been different.
1428 /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1429 /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1430 /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1431 ///
1432 /// // Join the three relations looking for triangles, like so.
1433 /// //
1434 /// // Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1435 /// let joined = MirRelationExpr::join(
1436 /// vec![input0, input1, input2],
1437 /// vec![
1438 /// vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1439 /// vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1440 /// vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1441 /// ],
1442 /// );
1443 ///
1444 /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1445 /// // A projection resolves this and produces the correct output.
1446 /// let result = joined.project(vec![0, 1, 3]);
1447 /// ```
1448 pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1449 let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1450
1451 let equivalences = variables
1452 .into_iter()
1453 .map(|vs| {
1454 vs.into_iter()
1455 .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1456 .collect::<Vec<_>>()
1457 })
1458 .collect::<Vec<_>>();
1459
1460 Self::join_scalars(inputs, equivalences)
1461 }
1462
1463 /// Constructs a join operator from inputs and required-equal scalar expressions.
1464 pub fn join_scalars(
1465 mut inputs: Vec<MirRelationExpr>,
1466 equivalences: Vec<Vec<MirScalarExpr>>,
1467 ) -> Self {
1468 // Remove all constant inputs that are the identity for join.
1469 // They neither introduce nor modify any column references.
1470 inputs.retain(|i| !i.is_constant_singleton());
1471 MirRelationExpr::Join {
1472 inputs,
1473 equivalences,
1474 implementation: JoinImplementation::Unimplemented,
1475 }
1476 }
1477
1478 /// Perform a key-wise reduction / aggregation.
1479 ///
1480 /// The `group_key` argument indicates columns in the input collection that should
1481 /// be grouped, and `aggregates` lists aggregation functions each of which produces
1482 /// one output column in addition to the keys.
1483 pub fn reduce(
1484 self,
1485 group_key: Vec<usize>,
1486 aggregates: Vec<AggregateExpr>,
1487 expected_group_size: Option<u64>,
1488 ) -> Self {
1489 MirRelationExpr::Reduce {
1490 input: Box::new(self),
1491 group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1492 aggregates,
1493 monotonic: false,
1494 expected_group_size,
1495 }
1496 }
1497
1498 /// Perform a key-wise reduction order by and limit.
1499 ///
1500 /// The `group_key` argument indicates columns in the input collection that should
1501 /// be grouped, the `order_key` argument indicates columns that should be further
1502 /// used to order records within groups, and the `limit` argument constrains the
1503 /// total number of records that should be produced in each group.
1504 pub fn top_k(
1505 self,
1506 group_key: Vec<usize>,
1507 order_key: Vec<ColumnOrder>,
1508 limit: Option<MirScalarExpr>,
1509 offset: usize,
1510 expected_group_size: Option<u64>,
1511 ) -> Self {
1512 MirRelationExpr::TopK {
1513 input: Box::new(self),
1514 group_key,
1515 order_key,
1516 limit,
1517 offset,
1518 expected_group_size,
1519 monotonic: false,
1520 }
1521 }
1522
1523 /// Negates the occurrences of each row.
1524 pub fn negate(self) -> Self {
1525 if let MirRelationExpr::Negate { input } = self {
1526 *input
1527 } else {
1528 MirRelationExpr::Negate {
1529 input: Box::new(self),
1530 }
1531 }
1532 }
1533
1534 /// Removes all but the first occurrence of each row.
1535 pub fn distinct(self) -> Self {
1536 let arity = self.arity();
1537 self.distinct_by((0..arity).collect())
1538 }
1539
1540 /// Removes all but the first occurrence of each key. Columns not included
1541 /// in the `group_key` are discarded.
1542 pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1543 self.reduce(group_key, vec![], None)
1544 }
1545
1546 /// Discards rows with a negative frequency.
1547 pub fn threshold(self) -> Self {
1548 if let MirRelationExpr::Threshold { .. } = &self {
1549 self
1550 } else {
1551 MirRelationExpr::Threshold {
1552 input: Box::new(self),
1553 }
1554 }
1555 }
1556
1557 /// Unions together any number inputs.
1558 ///
1559 /// If `inputs` is empty, then an empty relation of type `typ` is
1560 /// constructed.
1561 pub fn union_many(mut inputs: Vec<Self>, typ: SqlRelationType) -> Self {
1562 // Deconstruct `inputs` as `Union`s and reconstitute.
1563 let mut flat_inputs = Vec::with_capacity(inputs.len());
1564 for input in inputs {
1565 if let MirRelationExpr::Union { base, inputs } = input {
1566 flat_inputs.push(*base);
1567 flat_inputs.extend(inputs);
1568 } else {
1569 flat_inputs.push(input);
1570 }
1571 }
1572 inputs = flat_inputs;
1573 if inputs.len() == 0 {
1574 MirRelationExpr::Constant {
1575 rows: Ok(vec![]),
1576 typ,
1577 }
1578 } else if inputs.len() == 1 {
1579 inputs.into_element()
1580 } else {
1581 MirRelationExpr::Union {
1582 base: Box::new(inputs.remove(0)),
1583 inputs,
1584 }
1585 }
1586 }
1587
1588 /// Produces one collection where each row is present with the sum of its frequencies in each input.
1589 pub fn union(self, other: Self) -> Self {
1590 // Deconstruct `self` and `other` as `Union`s and reconstitute.
1591 let mut flat_inputs = Vec::with_capacity(2);
1592 if let MirRelationExpr::Union { base, inputs } = self {
1593 flat_inputs.push(*base);
1594 flat_inputs.extend(inputs);
1595 } else {
1596 flat_inputs.push(self);
1597 }
1598 if let MirRelationExpr::Union { base, inputs } = other {
1599 flat_inputs.push(*base);
1600 flat_inputs.extend(inputs);
1601 } else {
1602 flat_inputs.push(other);
1603 }
1604
1605 MirRelationExpr::Union {
1606 base: Box::new(flat_inputs.remove(0)),
1607 inputs: flat_inputs,
1608 }
1609 }
1610
1611 /// Arranges the collection by the specified columns
1612 pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1613 MirRelationExpr::ArrangeBy {
1614 input: Box::new(self),
1615 keys: keys.to_owned(),
1616 }
1617 }
1618
1619 /// Indicates if this is a constant empty collection.
1620 ///
1621 /// A false value does not mean the collection is known to be non-empty,
1622 /// only that we cannot currently determine that it is statically empty.
1623 pub fn is_empty(&self) -> bool {
1624 if let Some((Ok(rows), ..)) = self.as_const() {
1625 rows.is_empty()
1626 } else {
1627 false
1628 }
1629 }
1630
1631 /// If the expression is a negated project, return the input and the projection.
1632 pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1633 if let MirRelationExpr::Negate { input } = self {
1634 if let MirRelationExpr::Project { input, outputs } = &**input {
1635 return Some((&**input, outputs));
1636 }
1637 }
1638 if let MirRelationExpr::Project { input, outputs } = self {
1639 if let MirRelationExpr::Negate { input } = &**input {
1640 return Some((&**input, outputs));
1641 }
1642 }
1643 None
1644 }
1645
1646 /// Pretty-print this [MirRelationExpr] to a string.
1647 pub fn pretty(&self) -> String {
1648 let config = ExplainConfig::default();
1649 self.explain(&config, None)
1650 }
1651
1652 /// Pretty-print this [MirRelationExpr] to a string using a custom
1653 /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1654 pub fn explain(&self, config: &ExplainConfig, humanizer: Option<&dyn ExprHumanizer>) -> String {
1655 text_string_at(self, || PlanRenderingContext {
1656 indent: Indent::default(),
1657 humanizer: humanizer.unwrap_or(&DummyHumanizer),
1658 annotations: BTreeMap::default(),
1659 config,
1660 })
1661 }
1662
1663 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1664 /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1665 /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1666 /// the best possible key and nullability, since we are making an empty collection.
1667 ///
1668 /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1669 /// the correct type.
1670 pub fn take_safely(&mut self, typ: Option<SqlRelationType>) -> MirRelationExpr {
1671 if let Some(typ) = &typ {
1672 soft_assert_no_log!(
1673 self.typ()
1674 .column_types
1675 .iter()
1676 .zip_eq(typ.column_types.iter())
1677 .all(|(t1, t2)| t1.scalar_type.base_eq(&t2.scalar_type))
1678 );
1679 }
1680 let mut typ = typ.unwrap_or_else(|| self.typ());
1681 typ.keys = vec![vec![]];
1682 for ct in typ.column_types.iter_mut() {
1683 ct.nullable = false;
1684 }
1685 std::mem::replace(
1686 self,
1687 MirRelationExpr::Constant {
1688 rows: Ok(vec![]),
1689 typ,
1690 },
1691 )
1692 }
1693
1694 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1695 /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1696 /// possible nullability, since we are making an empty collection.
1697 pub fn take_safely_with_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1698 self.take_safely(Some(SqlRelationType::new(typ)))
1699 }
1700
1701 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1702 ///
1703 /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1704 pub fn take_dangerous(&mut self) -> MirRelationExpr {
1705 let empty = MirRelationExpr::Constant {
1706 rows: Ok(vec![]),
1707 typ: SqlRelationType::new(Vec::new()),
1708 };
1709 std::mem::replace(self, empty)
1710 }
1711
1712 /// Replaces `self` with some logic applied to `self`.
1713 pub fn replace_using<F>(&mut self, logic: F)
1714 where
1715 F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1716 {
1717 let empty = MirRelationExpr::Constant {
1718 rows: Ok(vec![]),
1719 typ: SqlRelationType::new(Vec::new()),
1720 };
1721 let expr = std::mem::replace(self, empty);
1722 *self = logic(expr);
1723 }
1724
1725 /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1726 pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1727 where
1728 Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1729 {
1730 if let MirRelationExpr::Get { .. } = self {
1731 // already done
1732 body(id_gen, self)
1733 } else {
1734 let id = LocalId::new(id_gen.allocate_id());
1735 let get = MirRelationExpr::Get {
1736 id: Id::Local(id),
1737 typ: self.typ(),
1738 access_strategy: AccessStrategy::UnknownOrLocal,
1739 };
1740 let body = (body)(id_gen, get)?;
1741 Ok(MirRelationExpr::Let {
1742 id,
1743 value: Box::new(self),
1744 body: Box::new(body),
1745 })
1746 }
1747 }
1748
    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
    pub fn anti_lookup<E>(
        self,
        id_gen: &mut IdGen,
        keys_and_values: MirRelationExpr,
        default: Vec<(Datum, SqlScalarType)>,
    ) -> Result<MirRelationExpr, E> {
        // Split `default` into datums and column types, marking a column
        // nullable exactly when its default datum is null.
        let (data, column_types): (Vec<_>, Vec<_>) = default
            .into_iter()
            .map(|(datum, scalar_type)| (datum, scalar_type.nullable(datum.is_null())))
            .unzip();
        // `default` must cover exactly the non-key columns of `keys_and_values`.
        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
        self.let_in(id_gen, |_id_gen, get_keys| {
            let get_keys_arity = get_keys.arity();
            Ok(MirRelationExpr::join(
                vec![
                    // all the missing keys (with count 1)
                    keys_and_values
                        .distinct_by((0..get_keys_arity).collect())
                        .negate()
                        .union(get_keys.clone().distinct()),
                    // join with keys to get the correct counts
                    get_keys.clone(),
                ],
                // Equate each key column of the two join inputs.
                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // get rid of the extra copies of columns from keys
            .project((0..get_keys_arity).collect())
            // This join is logically equivalent to
            // `.map(<default_expr>)`, but using a join allows for
            // potential predicate pushdown and elision in the
            // optimizer.
            .product(MirRelationExpr::constant(
                vec![data],
                SqlRelationType::new(column_types),
            )))
        })
    }
1788
1789 /// Return:
1790 /// * every row in keys_and_values
1791 /// * every row in `self` that does not have a matching row in the first columns of
1792 /// `keys_and_values`, using `default` to fill in the remaining columns
1793 /// (This is LEFT OUTER JOIN if:
1794 /// 1) `default` is a row of null
1795 /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1796 pub fn lookup<E>(
1797 self,
1798 id_gen: &mut IdGen,
1799 keys_and_values: MirRelationExpr,
1800 default: Vec<(Datum<'static>, SqlScalarType)>,
1801 ) -> Result<MirRelationExpr, E> {
1802 keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1803 Ok(get_keys_and_values.clone().union(self.anti_lookup(
1804 id_gen,
1805 get_keys_and_values,
1806 default,
1807 )?))
1808 })
1809 }
1810
1811 /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1812 pub fn contains_temporal(&self) -> bool {
1813 let mut contains = false;
1814 self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1815 contains
1816 }
1817
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                // Visit the join equivalence classes.
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // Also visit the key expressions recorded in the chosen join
                // implementation, which are stored separately from the
                // equivalences.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        // One path per input; each path is a sequence of lookups.
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                // Each aggregate owns a single scalar argument expression.
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These variants own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1917
1918 /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1919 /// rooted at `self`.
1920 ///
1921 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1922 /// nodes.
1923 pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1924 where
1925 F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1926 E: From<RecursionLimitError>,
1927 {
1928 self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1929 }
1930
1931 /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1932 /// rooted at `self`.
1933 ///
1934 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1935 /// nodes.
1936 pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1937 where
1938 F: FnMut(&mut MirScalarExpr),
1939 {
1940 self.try_visit_scalars_mut(&mut |s| {
1941 f(s);
1942 Ok::<_, RecursionLimitError>(())
1943 })
1944 .expect("Unexpected error in `visit_scalars_mut` call");
1945 }
1946
1947 /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1948 ///
1949 /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1950 pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1951 where
1952 F: FnMut(&MirScalarExpr) -> Result<(), E>,
1953 {
1954 use MirRelationExpr::*;
1955 match self {
1956 Map { scalars, .. } => {
1957 for s in scalars {
1958 f(s)?;
1959 }
1960 }
1961 Filter { predicates, .. } => {
1962 for p in predicates {
1963 f(p)?;
1964 }
1965 }
1966 FlatMap { exprs, .. } => {
1967 for expr in exprs {
1968 f(expr)?;
1969 }
1970 }
1971 Join {
1972 inputs: _,
1973 equivalences,
1974 implementation,
1975 } => {
1976 for equivalence in equivalences {
1977 for expr in equivalence {
1978 f(expr)?;
1979 }
1980 }
1981 match implementation {
1982 JoinImplementation::Differential((_, start_key, _), order) => {
1983 if let Some(start_key) = start_key {
1984 for k in start_key {
1985 f(k)?;
1986 }
1987 }
1988 for (_, lookup_key, _) in order {
1989 for k in lookup_key {
1990 f(k)?;
1991 }
1992 }
1993 }
1994 JoinImplementation::DeltaQuery(paths) => {
1995 for path in paths {
1996 for (_, lookup_key, _) in path {
1997 for k in lookup_key {
1998 f(k)?;
1999 }
2000 }
2001 }
2002 }
2003 JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
2004 for k in index_key {
2005 f(k)?;
2006 }
2007 }
2008 JoinImplementation::Unimplemented => {} // No scalar exprs
2009 }
2010 }
2011 ArrangeBy { keys, .. } => {
2012 for key in keys {
2013 for s in key {
2014 f(s)?;
2015 }
2016 }
2017 }
2018 Reduce {
2019 group_key,
2020 aggregates,
2021 ..
2022 } => {
2023 for s in group_key {
2024 f(s)?;
2025 }
2026 for agg in aggregates {
2027 f(&agg.expr)?;
2028 }
2029 }
2030 TopK { limit, .. } => {
2031 if let Some(s) = limit {
2032 f(s)?;
2033 }
2034 }
2035 Constant { .. }
2036 | Get { .. }
2037 | Let { .. }
2038 | LetRec { .. }
2039 | Project { .. }
2040 | Negate { .. }
2041 | Threshold { .. }
2042 | Union { .. } => (),
2043 }
2044 Ok(())
2045 }
2046
2047 /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2048 /// rooted at `self`.
2049 ///
2050 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2051 /// nodes.
2052 pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
2053 where
2054 F: FnMut(&MirScalarExpr) -> Result<(), E>,
2055 E: From<RecursionLimitError>,
2056 {
2057 self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
2058 }
2059
2060 /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2061 /// rooted at `self`.
2062 ///
2063 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2064 /// nodes.
2065 pub fn visit_scalars<F>(&self, f: &mut F)
2066 where
2067 F: FnMut(&MirScalarExpr),
2068 {
2069 self.try_visit_scalars(&mut |s| {
2070 f(s);
2071 Ok::<_, RecursionLimitError>(())
2072 })
2073 .expect("Unexpected error in `visit_scalars` call");
2074 }
2075
2076 /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
2077 /// stack overflow in `drop_in_place`.
2078 ///
2079 /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
2080 /// dropped or otherwise overwritten.
2081 pub fn destroy_carefully(&mut self) {
2082 let mut todo = vec![self.take_dangerous()];
2083 while let Some(mut expr) = todo.pop() {
2084 for child in expr.children_mut() {
2085 todo.push(child.take_dangerous());
2086 }
2087 }
2088 }
2089
2090 /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
2091 /// debug printing purposes.
2092 pub fn debug_size_and_depth(&self) -> (usize, usize) {
2093 let mut size = 0;
2094 let mut max_depth = 0;
2095 let mut todo = vec![(self, 1)];
2096 while let Some((expr, depth)) = todo.pop() {
2097 size += 1;
2098 max_depth = max(max_depth, depth);
2099 todo.extend(expr.children().map(|c| (c, depth + 1)));
2100 }
2101 (size, max_depth)
2102 }
2103
2104 /// The MirRelationExpr is considered potentially expensive if and only if
2105 /// at least one of the following conditions is true:
2106 ///
2107 /// - It contains at least one FlatMap or a Reduce operator.
2108 /// - It contains at least one MirScalarExpr with a function call.
2109 ///
2110 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2111 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2112 pub fn could_run_expensive_function(&self) -> bool {
2113 let mut result = false;
2114 self.visit_pre(|e: &MirRelationExpr| {
2115 use MirRelationExpr::*;
2116 use MirScalarExpr::*;
2117 if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
2118 result |= match scalar {
2119 Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
2120 // Function calls are considered expensive
2121 CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
2122 };
2123 Ok(())
2124 }) {
2125 // Conservatively set `true` if on RecursionLimitError.
2126 result = true;
2127 }
2128 // FlatMap has a table function; Reduce has an aggregate function.
2129 // Other constructs use MirScalarExpr to run a function
2130 result |= matches!(e, FlatMap { .. } | Reduce { .. });
2131 });
2132 result
2133 }
2134
2135 /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2136 /// than what `Hashable::hashed` would give us.)
2137 pub fn hash_to_u64(&self) -> u64 {
2138 let mut h = DefaultHasher::new();
2139 self.hash(&mut h);
2140 h.finish()
2141 }
2142}
2143
2144// `LetRec` helpers
2145impl MirRelationExpr {
2146 /// True when `expr` contains a `LetRec` AST node.
2147 pub fn is_recursive(self: &MirRelationExpr) -> bool {
2148 let mut worklist = vec![self];
2149 while let Some(expr) = worklist.pop() {
2150 if let MirRelationExpr::LetRec { .. } = expr {
2151 return true;
2152 }
2153 worklist.extend(expr.children());
2154 }
2155 false
2156 }
2157
2158 /// Return the number of sub-expressions in the tree (including self).
2159 pub fn size(&self) -> usize {
2160 let mut size = 0;
2161 self.visit_pre(|_| size += 1);
2162 size
2163 }
2164
2165 /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2166 /// iterations. These are those ids that have a reference before they are defined, when reading
2167 /// all the bindings in order.
2168 ///
2169 /// For example:
2170 /// ```SQL
2171 /// WITH MUTUALLY RECURSIVE
2172 /// x(...) AS f(z),
2173 /// y(...) AS g(x),
2174 /// z(...) AS h(y)
2175 /// ...;
2176 /// ```
2177 /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2178 /// iteration.
2179 ///
2180 /// Note that if a binding references itself, that is also returned.
2181 pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2182 let mut used_across_iterations = BTreeSet::new();
2183 let mut defined = BTreeSet::new();
2184 for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2185 value.visit_pre(|expr| {
2186 if let MirRelationExpr::Get {
2187 id: Local(get_id), ..
2188 } = expr
2189 {
2190 // If we haven't seen a definition for it yet, then this will refer
2191 // to the previous iteration.
2192 // The `ids.contains` part of the condition is needed to exclude
2193 // those ids that are not really in this LetRec, but either an inner
2194 // or outer one.
2195 if !defined.contains(get_id) && ids.contains(get_id) {
2196 used_across_iterations.insert(*get_id);
2197 }
2198 }
2199 });
2200 defined.insert(*binding_id);
2201 }
2202 used_across_iterations
2203 }
2204
    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
    ///
    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
    /// identifiers are rewritten to be the constant collection.
    /// This makes the computation perform exactly "one" iteration.
    ///
    /// This was used only temporarily while developing `LetRec`.
    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
        // Ids whose (possibly recursive) references must be replaced by constants.
        let mut deadlist = BTreeSet::new();
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            if let MirRelationExpr::LetRec {
                ids,
                values,
                limits: _,
                body,
            } = expr
            {
                // Pair each binding's id with its value, draining `values` in place.
                let ids_values = values
                    .drain(..)
                    .zip_eq(ids)
                    .map(|(value, id)| (*id, value))
                    .collect::<Vec<_>>();
                // Start from the body; `Let` nodes are wrapped around it below,
                // innermost binding first (hence the `.rev()`).
                *expr = body.take_dangerous();
                for (id, mut value) in ids_values.into_iter().rev() {
                    // Remove references to potentially recursive identifiers.
                    deadlist.insert(id);
                    value.visit_pre_mut(|e| {
                        if let MirRelationExpr::Get {
                            id: crate::Id::Local(id),
                            typ,
                            ..
                        } = e
                        {
                            let typ = typ.clone();
                            if deadlist.contains(id) {
                                // Replace the forward/self reference with an
                                // empty constant of the same type.
                                e.take_safely(Some(typ));
                            }
                        }
                    });
                    // Wrap the accumulated expression in this binding's `Let`.
                    *expr = MirRelationExpr::Let {
                        id,
                        value: Box::new(value),
                        body: Box::new(expr.take_dangerous()),
                    };
                }
                // Revisit the rewritten node to handle any nested `LetRec`s.
                worklist.push(expr);
            } else {
                worklist.extend(expr.children_mut().rev());
            }
        }
    }
2257
2258 /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2259 /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2260 /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2261 /// redefinition.
2262 ///
2263 /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2264 pub fn collect_expirations(
2265 id: LocalId,
2266 expr: &MirRelationExpr,
2267 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2268 ) {
2269 expr.visit_pre(|e| {
2270 if let MirRelationExpr::Get {
2271 id: Id::Local(referenced_id),
2272 ..
2273 } = e
2274 {
2275 // The following check needs `renumber_bindings` to have run recently
2276 if referenced_id >= &id {
2277 expire_whens
2278 .entry(*referenced_id)
2279 .or_insert_with(Vec::new)
2280 .push(id);
2281 }
2282 }
2283 });
2284 }
2285
2286 /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2287 /// about such Ids whose information depended on the earlier definition of `id`, according to
2288 /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2289 pub fn do_expirations<I>(
2290 redefined_id: LocalId,
2291 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2292 id_infos: &mut BTreeMap<LocalId, I>,
2293 ) -> Vec<(LocalId, I)> {
2294 let mut expired_infos = Vec::new();
2295 if let Some(expirations) = expire_whens.remove(&redefined_id) {
2296 for expired_id in expirations.into_iter() {
2297 if let Some(offer) = id_infos.remove(&expired_id) {
2298 expired_infos.push((expired_id, offer));
2299 }
2300 }
2301 }
2302 expired_infos
2303 }
2304}
2305/// Augment non-nullability of columns, by observing either
2306/// 1. Predicates that explicitly test for null values, and
2307/// 2. Columns that if null would make a predicate be null.
2308pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2309 let mut nonnull_required_columns = BTreeSet::new();
2310 for predicate in predicates {
2311 // Add any columns that being null would force the predicate to be null.
2312 // Should that happen, the row would be discarded.
2313 predicate.non_null_requirements(&mut nonnull_required_columns);
2314
2315 /*
2316 Test for explicit checks that a column is non-null.
2317
2318 This analysis is ad hoc, and will miss things:
2319
2320 materialize=> create table a(x int, y int);
2321 CREATE TABLE
2322 materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2323 Optimized Plan
2324 --------------------------------------------------------------------------------------------------------
2325 Explained Query: +
2326 Project (#0) // { types: "(integer?)" } +
2327 Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2328 Get materialize.public.a // { types: "(integer?, integer?)" } +
2329 +
2330 Source materialize.public.a +
2331 filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1)))) +
2332
2333 (1 row)
2334 */
2335
2336 if let MirScalarExpr::CallUnary {
2337 func: UnaryFunc::Not(scalar_func::Not),
2338 expr,
2339 } = predicate
2340 {
2341 if let MirScalarExpr::CallUnary {
2342 func: UnaryFunc::IsNull(scalar_func::IsNull),
2343 expr,
2344 } = &**expr
2345 {
2346 if let MirScalarExpr::Column(c, _name) = &**expr {
2347 nonnull_required_columns.insert(*c);
2348 }
2349 }
2350 }
2351 }
2352
2353 nonnull_required_columns
2354}
2355
2356impl CollectionPlan for MirRelationExpr {
2357 /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2358 /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2359 ///
2360 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2361 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2362 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2363 if let MirRelationExpr::Get {
2364 id: Id::Global(id), ..
2365 } = self
2366 {
2367 out.insert(*id);
2368 }
2369 self.visit_children(|expr| expr.depends_on_into(out))
2370 }
2371}
2372
2373impl MirRelationExpr {
2374 /// Iterates through references to child expressions.
2375 pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2376 let mut first = None;
2377 let mut second = None;
2378 let mut rest = None;
2379 let mut last = None;
2380
2381 use MirRelationExpr::*;
2382 match self {
2383 Constant { .. } | Get { .. } => (),
2384 Let { value, body, .. } => {
2385 first = Some(&**value);
2386 second = Some(&**body);
2387 }
2388 LetRec { values, body, .. } => {
2389 rest = Some(values);
2390 last = Some(&**body);
2391 }
2392 Project { input, .. }
2393 | Map { input, .. }
2394 | FlatMap { input, .. }
2395 | Filter { input, .. }
2396 | Reduce { input, .. }
2397 | TopK { input, .. }
2398 | Negate { input }
2399 | Threshold { input }
2400 | ArrangeBy { input, .. } => {
2401 first = Some(&**input);
2402 }
2403 Join { inputs, .. } => {
2404 rest = Some(inputs);
2405 }
2406 Union { base, inputs } => {
2407 first = Some(&**base);
2408 rest = Some(inputs);
2409 }
2410 }
2411
2412 first
2413 .into_iter()
2414 .chain(second)
2415 .chain(rest.into_iter().flatten())
2416 .chain(last)
2417 }
2418
2419 /// Iterates through mutable references to child expressions.
2420 pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2421 let mut first = None;
2422 let mut second = None;
2423 let mut rest = None;
2424 let mut last = None;
2425
2426 use MirRelationExpr::*;
2427 match self {
2428 Constant { .. } | Get { .. } => (),
2429 Let { value, body, .. } => {
2430 first = Some(&mut **value);
2431 second = Some(&mut **body);
2432 }
2433 LetRec { values, body, .. } => {
2434 rest = Some(values);
2435 last = Some(&mut **body);
2436 }
2437 Project { input, .. }
2438 | Map { input, .. }
2439 | FlatMap { input, .. }
2440 | Filter { input, .. }
2441 | Reduce { input, .. }
2442 | TopK { input, .. }
2443 | Negate { input }
2444 | Threshold { input }
2445 | ArrangeBy { input, .. } => {
2446 first = Some(&mut **input);
2447 }
2448 Join { inputs, .. } => {
2449 rest = Some(inputs);
2450 }
2451 Union { base, inputs } => {
2452 first = Some(&mut **base);
2453 rest = Some(inputs);
2454 }
2455 }
2456
2457 first
2458 .into_iter()
2459 .chain(second)
2460 .chain(rest.into_iter().flatten())
2461 .chain(last)
2462 }
2463
2464 /// Iterative pre-order visitor.
2465 pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2466 let mut worklist = vec![self];
2467 while let Some(expr) = worklist.pop() {
2468 f(expr);
2469 worklist.extend(expr.children().rev());
2470 }
2471 }
2472
2473 /// Iterative pre-order visitor.
2474 pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2475 let mut worklist = vec![self];
2476 while let Some(expr) = worklist.pop() {
2477 f(expr);
2478 worklist.extend(expr.children_mut().rev());
2479 }
2480 }
2481
2482 /// Return a vector of references to the subtrees of this expression
2483 /// in post-visit order (the last element is `&self`).
2484 pub fn post_order_vec(&self) -> Vec<&Self> {
2485 let mut stack = vec![self];
2486 let mut result = vec![];
2487 while let Some(expr) = stack.pop() {
2488 result.push(expr);
2489 stack.extend(expr.children());
2490 }
2491 result.reverse();
2492 result
2493 }
2494}
2495
2496impl VisitChildren<Self> for MirRelationExpr {
2497 fn visit_children<F>(&self, mut f: F)
2498 where
2499 F: FnMut(&Self),
2500 {
2501 for child in self.children() {
2502 f(child)
2503 }
2504 }
2505
2506 fn visit_mut_children<F>(&mut self, mut f: F)
2507 where
2508 F: FnMut(&mut Self),
2509 {
2510 for child in self.children_mut() {
2511 f(child)
2512 }
2513 }
2514
2515 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2516 where
2517 F: FnMut(&Self) -> Result<(), E>,
2518 E: From<RecursionLimitError>,
2519 {
2520 for child in self.children() {
2521 f(child)?
2522 }
2523 Ok(())
2524 }
2525
2526 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2527 where
2528 F: FnMut(&mut Self) -> Result<(), E>,
2529 E: From<RecursionLimitError>,
2530 {
2531 for child in self.children_mut() {
2532 f(child)?
2533 }
2534 Ok(())
2535 }
2536}
2537
/// Specification for an ordering by a column.
#[derive(
    Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    /// `serde(default)` makes this `false` (ascending) when absent.
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    /// `serde(default)` makes this `false` (nulls first) when absent.
    #[serde(default)]
    pub nulls_last: bool,
}
2552
// `ColumnOrder` is `Copy`, so a simple copy region suffices for columnation.
impl Columnation for ColumnOrder {
    type InnerRegion = CopyRegion<Self>;
}
2556
2557impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2558where
2559 M: HumanizerMode,
2560{
2561 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2562 // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2563 write!(
2564 f,
2565 "{} {} {}",
2566 self.child(&self.expr.column),
2567 if self.expr.desc { "desc" } else { "asc" },
2568 if self.expr.nulls_last {
2569 "nulls_last"
2570 } else {
2571 "nulls_first"
2572 },
2573 )
2574 }
2575}
2576
/// Describes an aggregation expression.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    /// `serde(default)` makes this `false` when absent.
    #[serde(default)]
    pub distinct: bool,
}
2588
2589impl AggregateExpr {
2590 /// Computes the type of this `AggregateExpr`.
2591 pub fn typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2592 self.func.output_type(self.expr.typ(column_types))
2593 }
2594
2595 /// Computes the type of this `AggregateExpr`.
2596 pub fn repr_typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2597 ReprColumnType::from(
2598 &self
2599 .func
2600 .output_type(SqlColumnType::from_repr(&self.expr.repr_typ(column_types))),
2601 )
2602 }
2603
2604 /// Returns whether the expression has a constant result.
2605 pub fn is_constant(&self) -> bool {
2606 match self.func {
2607 AggregateFunc::MaxInt16
2608 | AggregateFunc::MaxInt32
2609 | AggregateFunc::MaxInt64
2610 | AggregateFunc::MaxUInt16
2611 | AggregateFunc::MaxUInt32
2612 | AggregateFunc::MaxUInt64
2613 | AggregateFunc::MaxMzTimestamp
2614 | AggregateFunc::MaxFloat32
2615 | AggregateFunc::MaxFloat64
2616 | AggregateFunc::MaxBool
2617 | AggregateFunc::MaxString
2618 | AggregateFunc::MaxDate
2619 | AggregateFunc::MaxTimestamp
2620 | AggregateFunc::MaxTimestampTz
2621 | AggregateFunc::MinInt16
2622 | AggregateFunc::MinInt32
2623 | AggregateFunc::MinInt64
2624 | AggregateFunc::MinUInt16
2625 | AggregateFunc::MinUInt32
2626 | AggregateFunc::MinUInt64
2627 | AggregateFunc::MinMzTimestamp
2628 | AggregateFunc::MinFloat32
2629 | AggregateFunc::MinFloat64
2630 | AggregateFunc::MinBool
2631 | AggregateFunc::MinString
2632 | AggregateFunc::MinDate
2633 | AggregateFunc::MinTimestamp
2634 | AggregateFunc::MinTimestampTz
2635 | AggregateFunc::Any
2636 | AggregateFunc::All
2637 | AggregateFunc::Dummy => self.expr.is_literal(),
2638 AggregateFunc::Count => self.expr.is_literal_null(),
2639 _ => self.expr.is_literal_err(),
2640 }
2641 }
2642
2643 /// Returns an expression that computes `self` on a group that has exactly one row.
2644 /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2645 /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2646 pub fn on_unique(&self, input_type: &[SqlColumnType]) -> MirScalarExpr {
2647 match &self.func {
2648 // Count is one if non-null, and zero if null.
2649 AggregateFunc::Count => self
2650 .expr
2651 .clone()
2652 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2653 .if_then_else(
2654 MirScalarExpr::literal_ok(Datum::Int64(0), SqlScalarType::Int64),
2655 MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
2656 ),
2657
2658 // SumInt16 takes Int16s as input, but outputs Int64s.
2659 AggregateFunc::SumInt16 => self
2660 .expr
2661 .clone()
2662 .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2663
2664 // SumInt32 takes Int32s as input, but outputs Int64s.
2665 AggregateFunc::SumInt32 => self
2666 .expr
2667 .clone()
2668 .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2669
2670 // SumInt64 takes Int64s as input, but outputs numerics.
2671 AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2672 scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2673 )),
2674
2675 // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2676 AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2677 UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2678 ),
2679
2680 // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2681 AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2682 UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2683 ),
2684
2685 // SumUInt64 takes UInt64s as input, but outputs numerics.
2686 AggregateFunc::SumUInt64 => {
2687 self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2688 scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2689 ))
2690 }
2691
2692 // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2693 AggregateFunc::JsonbAgg { .. } => MirScalarExpr::CallVariadic {
2694 func: VariadicFunc::JsonbBuildArray,
2695 exprs: vec![
2696 self.expr
2697 .clone()
2698 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2699 ],
2700 },
2701
2702 // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2703 AggregateFunc::JsonbObjectAgg { .. } => {
2704 let record = self
2705 .expr
2706 .clone()
2707 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2708 MirScalarExpr::CallVariadic {
2709 func: VariadicFunc::JsonbBuildObject,
2710 exprs: (0..2)
2711 .map(|i| {
2712 record
2713 .clone()
2714 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2715 })
2716 .collect(),
2717 }
2718 }
2719
2720 AggregateFunc::MapAgg { value_type, .. } => {
2721 let record = self
2722 .expr
2723 .clone()
2724 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2725 MirScalarExpr::CallVariadic {
2726 func: VariadicFunc::MapBuild {
2727 value_type: value_type.clone(),
2728 },
2729 exprs: (0..2)
2730 .map(|i| {
2731 record
2732 .clone()
2733 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2734 })
2735 .collect(),
2736 }
2737 }
2738
2739 // StringAgg takes nested records of strings and outputs a string
2740 AggregateFunc::StringAgg { .. } => self
2741 .expr
2742 .clone()
2743 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2744 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2745
2746 // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2747 AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2748 .expr
2749 .clone()
2750 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2751
2752 // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2753 AggregateFunc::RowNumber { .. } => {
2754 self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2755 }
2756 AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2757 AggregateFunc::DenseRank { .. } => {
2758 self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2759 }
2760
2761 // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2762 AggregateFunc::LagLead { lag_lead, .. } => {
2763 let tuple = self
2764 .expr
2765 .clone()
2766 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2767
2768 // Get the overall return type
2769 let return_type_with_orig_row = self
2770 .typ(input_type)
2771 .scalar_type
2772 .unwrap_list_element_type()
2773 .clone();
2774 let lag_lead_return_type =
2775 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2776
2777 // Extract the original row
2778 let original_row = tuple
2779 .clone()
2780 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2781
2782 // Extract the encoded args
2783 let encoded_args =
2784 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2785
2786 let (result_expr, column_name) =
2787 Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2788
2789 MirScalarExpr::CallVariadic {
2790 func: VariadicFunc::ListCreate {
2791 elem_type: return_type_with_orig_row,
2792 },
2793 exprs: vec![MirScalarExpr::CallVariadic {
2794 func: VariadicFunc::RecordCreate {
2795 field_names: vec![column_name, ColumnName::from("?record?")],
2796 },
2797 exprs: vec![result_expr, original_row],
2798 }],
2799 }
2800 }
2801
2802 // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2803 AggregateFunc::FirstValue { window_frame, .. } => {
2804 let tuple = self
2805 .expr
2806 .clone()
2807 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2808
2809 // Get the overall return type
2810 let return_type_with_orig_row = self
2811 .typ(input_type)
2812 .scalar_type
2813 .unwrap_list_element_type()
2814 .clone();
2815 let first_value_return_type =
2816 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2817
2818 // Extract the original row
2819 let original_row = tuple
2820 .clone()
2821 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2822
2823 // Extract the input value
2824 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2825
2826 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2827 window_frame,
2828 arg,
2829 first_value_return_type,
2830 );
2831
2832 MirScalarExpr::CallVariadic {
2833 func: VariadicFunc::ListCreate {
2834 elem_type: return_type_with_orig_row,
2835 },
2836 exprs: vec![MirScalarExpr::CallVariadic {
2837 func: VariadicFunc::RecordCreate {
2838 field_names: vec![column_name, ColumnName::from("?record?")],
2839 },
2840 exprs: vec![result_expr, original_row],
2841 }],
2842 }
2843 }
2844
2845 // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2846 AggregateFunc::LastValue { window_frame, .. } => {
2847 let tuple = self
2848 .expr
2849 .clone()
2850 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2851
2852 // Get the overall return type
2853 let return_type_with_orig_row = self
2854 .typ(input_type)
2855 .scalar_type
2856 .unwrap_list_element_type()
2857 .clone();
2858 let last_value_return_type =
2859 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2860
2861 // Extract the original row
2862 let original_row = tuple
2863 .clone()
2864 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2865
2866 // Extract the input value
2867 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2868
2869 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2870 window_frame,
2871 arg,
2872 last_value_return_type,
2873 );
2874
2875 MirScalarExpr::CallVariadic {
2876 func: VariadicFunc::ListCreate {
2877 elem_type: return_type_with_orig_row,
2878 },
2879 exprs: vec![MirScalarExpr::CallVariadic {
2880 func: VariadicFunc::RecordCreate {
2881 field_names: vec![column_name, ColumnName::from("?record?")],
2882 },
2883 exprs: vec![result_expr, original_row],
2884 }],
2885 }
2886 }
2887
2888 // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2889 // See an example MIR in `window_func_applied_to`.
2890 AggregateFunc::WindowAggregate {
2891 wrapped_aggregate,
2892 window_frame,
2893 order_by: _,
2894 } => {
2895 // TODO: deduplicate code between the various window function cases.
2896
2897 let tuple = self
2898 .expr
2899 .clone()
2900 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2901
2902 // Get the overall return type
2903 let return_type = self
2904 .typ(input_type)
2905 .scalar_type
2906 .unwrap_list_element_type()
2907 .clone();
2908 let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2909
2910 // Extract the original row
2911 let original_row = tuple
2912 .clone()
2913 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2914
2915 // Extract the input value
2916 let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2917
2918 let (result, column_name) = Self::on_unique_window_agg(
2919 window_frame,
2920 arg_expr,
2921 input_type,
2922 window_agg_return_type,
2923 wrapped_aggregate,
2924 );
2925
2926 MirScalarExpr::CallVariadic {
2927 func: VariadicFunc::ListCreate {
2928 elem_type: return_type,
2929 },
2930 exprs: vec![MirScalarExpr::CallVariadic {
2931 func: VariadicFunc::RecordCreate {
2932 field_names: vec![column_name, ColumnName::from("?record?")],
2933 },
2934 exprs: vec![result, original_row],
2935 }],
2936 }
2937 }
2938
2939 // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2940 AggregateFunc::FusedWindowAggregate {
2941 wrapped_aggregates,
2942 order_by: _,
2943 window_frame,
2944 } => {
2945 // Throw away OrderByExprs
2946 let tuple = self
2947 .expr
2948 .clone()
2949 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2950
2951 // Extract the original row
2952 let original_row = tuple
2953 .clone()
2954 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2955
2956 // Extract the args of the fused call
2957 let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2958
2959 let return_type_with_orig_row = self
2960 .typ(input_type)
2961 .scalar_type
2962 .unwrap_list_element_type()
2963 .clone();
2964
2965 let all_func_return_types =
2966 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2967 let mut func_result_exprs = Vec::new();
2968 let mut col_names = Vec::new();
2969 for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2970 let arg = all_args
2971 .clone()
2972 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2973 let return_type =
2974 all_func_return_types.unwrap_record_element_type()[idx].clone();
2975 let (result, column_name) = Self::on_unique_window_agg(
2976 window_frame,
2977 arg,
2978 input_type,
2979 return_type,
2980 wrapped_aggr,
2981 );
2982 func_result_exprs.push(result);
2983 col_names.push(column_name);
2984 }
2985
2986 MirScalarExpr::CallVariadic {
2987 func: VariadicFunc::ListCreate {
2988 elem_type: return_type_with_orig_row,
2989 },
2990 exprs: vec![MirScalarExpr::CallVariadic {
2991 func: VariadicFunc::RecordCreate {
2992 field_names: vec![
2993 ColumnName::from("?fused_window_aggr?"),
2994 ColumnName::from("?record?"),
2995 ],
2996 },
2997 exprs: vec![
2998 MirScalarExpr::CallVariadic {
2999 func: VariadicFunc::RecordCreate {
3000 field_names: col_names,
3001 },
3002 exprs: func_result_exprs,
3003 },
3004 original_row,
3005 ],
3006 }],
3007 }
3008 }
3009
3010 // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
3011 AggregateFunc::FusedValueWindowFunc {
3012 funcs,
3013 order_by: outer_order_by,
3014 } => {
3015 // Throw away OrderByExprs
3016 let tuple = self
3017 .expr
3018 .clone()
3019 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3020
3021 // Extract the original row
3022 let original_row = tuple
3023 .clone()
3024 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3025
3026 // Extract the encoded args of the fused call
3027 let all_encoded_args =
3028 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3029
3030 let return_type_with_orig_row = self
3031 .typ(input_type)
3032 .scalar_type
3033 .unwrap_list_element_type()
3034 .clone();
3035
3036 let all_func_return_types =
3037 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3038 let mut func_result_exprs = Vec::new();
3039 let mut col_names = Vec::new();
3040 for (idx, func) in funcs.iter().enumerate() {
3041 let args_for_func = all_encoded_args
3042 .clone()
3043 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3044 let return_type_for_func =
3045 all_func_return_types.unwrap_record_element_type()[idx].clone();
3046 let (result, column_name) = match func {
3047 AggregateFunc::LagLead {
3048 lag_lead,
3049 order_by,
3050 ignore_nulls: _,
3051 } => {
3052 assert_eq!(order_by, outer_order_by);
3053 Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
3054 }
3055 AggregateFunc::FirstValue {
3056 window_frame,
3057 order_by,
3058 } => {
3059 assert_eq!(order_by, outer_order_by);
3060 Self::on_unique_first_value_last_value(
3061 window_frame,
3062 args_for_func,
3063 return_type_for_func,
3064 )
3065 }
3066 AggregateFunc::LastValue {
3067 window_frame,
3068 order_by,
3069 } => {
3070 assert_eq!(order_by, outer_order_by);
3071 Self::on_unique_first_value_last_value(
3072 window_frame,
3073 args_for_func,
3074 return_type_for_func,
3075 )
3076 }
3077 _ => panic!("unknown function in FusedValueWindowFunc"),
3078 };
3079 func_result_exprs.push(result);
3080 col_names.push(column_name);
3081 }
3082
3083 MirScalarExpr::CallVariadic {
3084 func: VariadicFunc::ListCreate {
3085 elem_type: return_type_with_orig_row,
3086 },
3087 exprs: vec![MirScalarExpr::CallVariadic {
3088 func: VariadicFunc::RecordCreate {
3089 field_names: vec![
3090 ColumnName::from("?fused_value_window_func?"),
3091 ColumnName::from("?record?"),
3092 ],
3093 },
3094 exprs: vec![
3095 MirScalarExpr::CallVariadic {
3096 func: VariadicFunc::RecordCreate {
3097 field_names: col_names,
3098 },
3099 exprs: func_result_exprs,
3100 },
3101 original_row,
3102 ],
3103 }],
3104 }
3105 }
3106
3107 // All other variants should return the argument to the aggregation.
3108 AggregateFunc::MaxNumeric
3109 | AggregateFunc::MaxInt16
3110 | AggregateFunc::MaxInt32
3111 | AggregateFunc::MaxInt64
3112 | AggregateFunc::MaxUInt16
3113 | AggregateFunc::MaxUInt32
3114 | AggregateFunc::MaxUInt64
3115 | AggregateFunc::MaxMzTimestamp
3116 | AggregateFunc::MaxFloat32
3117 | AggregateFunc::MaxFloat64
3118 | AggregateFunc::MaxBool
3119 | AggregateFunc::MaxString
3120 | AggregateFunc::MaxDate
3121 | AggregateFunc::MaxTimestamp
3122 | AggregateFunc::MaxTimestampTz
3123 | AggregateFunc::MaxInterval
3124 | AggregateFunc::MaxTime
3125 | AggregateFunc::MinNumeric
3126 | AggregateFunc::MinInt16
3127 | AggregateFunc::MinInt32
3128 | AggregateFunc::MinInt64
3129 | AggregateFunc::MinUInt16
3130 | AggregateFunc::MinUInt32
3131 | AggregateFunc::MinUInt64
3132 | AggregateFunc::MinMzTimestamp
3133 | AggregateFunc::MinFloat32
3134 | AggregateFunc::MinFloat64
3135 | AggregateFunc::MinBool
3136 | AggregateFunc::MinString
3137 | AggregateFunc::MinDate
3138 | AggregateFunc::MinTimestamp
3139 | AggregateFunc::MinTimestampTz
3140 | AggregateFunc::MinInterval
3141 | AggregateFunc::MinTime
3142 | AggregateFunc::SumFloat32
3143 | AggregateFunc::SumFloat64
3144 | AggregateFunc::SumNumeric
3145 | AggregateFunc::Any
3146 | AggregateFunc::All
3147 | AggregateFunc::Dummy => self.expr.clone(),
3148 }
3149 }
3150
3151 /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3152 fn on_unique_ranking_window_funcs(
3153 &self,
3154 input_type: &[SqlColumnType],
3155 col_name: &str,
3156 ) -> MirScalarExpr {
3157 let list = self
3158 .expr
3159 .clone()
3160 // extract the list within the record
3161 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3162
3163 // extract the expression within the list
3164 let record = MirScalarExpr::CallVariadic {
3165 func: VariadicFunc::ListIndex,
3166 exprs: vec![
3167 list,
3168 MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
3169 ],
3170 };
3171
3172 MirScalarExpr::CallVariadic {
3173 func: VariadicFunc::ListCreate {
3174 elem_type: self
3175 .typ(input_type)
3176 .scalar_type
3177 .unwrap_list_element_type()
3178 .clone(),
3179 },
3180 exprs: vec![MirScalarExpr::CallVariadic {
3181 func: VariadicFunc::RecordCreate {
3182 field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3183 },
3184 exprs: vec![
3185 MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
3186 record,
3187 ],
3188 }],
3189 }
3190 }
3191
3192 /// `on_unique` for `lag` and `lead`
3193 fn on_unique_lag_lead(
3194 lag_lead: &LagLeadType,
3195 encoded_args: MirScalarExpr,
3196 return_type: SqlScalarType,
3197 ) -> (MirScalarExpr, ColumnName) {
3198 let expr = encoded_args
3199 .clone()
3200 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3201 let offset = encoded_args
3202 .clone()
3203 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3204 let default_value =
3205 encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3206
3207 // In this case, the window always has only one element, so if the offset is not null and
3208 // not zero, the default value should be returned instead.
3209 let value = offset
3210 .clone()
3211 .call_binary(
3212 MirScalarExpr::literal_ok(Datum::Int32(0), SqlScalarType::Int32),
3213 crate::func::Eq,
3214 )
3215 .if_then_else(expr, default_value);
3216 let result_expr = offset
3217 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3218 .if_then_else(MirScalarExpr::literal_null(return_type), value);
3219
3220 let column_name = ColumnName::from(match lag_lead {
3221 LagLeadType::Lag => "?lag?",
3222 LagLeadType::Lead => "?lead?",
3223 });
3224
3225 (result_expr, column_name)
3226 }
3227
3228 /// `on_unique` for `first_value` and `last_value`
3229 fn on_unique_first_value_last_value(
3230 window_frame: &WindowFrame,
3231 arg: MirScalarExpr,
3232 return_type: SqlScalarType,
3233 ) -> (MirScalarExpr, ColumnName) {
3234 // If the window frame includes the current (single) row, return its value, null otherwise
3235 let result_expr = if window_frame.includes_current_row() {
3236 arg
3237 } else {
3238 MirScalarExpr::literal_null(return_type)
3239 };
3240 (result_expr, ColumnName::from("?first_value?"))
3241 }
3242
3243 /// `on_unique` for window aggregations
3244 fn on_unique_window_agg(
3245 window_frame: &WindowFrame,
3246 arg_expr: MirScalarExpr,
3247 input_type: &[SqlColumnType],
3248 return_type: SqlScalarType,
3249 wrapped_aggr: &AggregateFunc,
3250 ) -> (MirScalarExpr, ColumnName) {
3251 // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3252 // that row. Otherwise, return the default value for the aggregate.
3253 let result_expr = if window_frame.includes_current_row() {
3254 AggregateExpr {
3255 func: wrapped_aggr.clone(),
3256 expr: arg_expr,
3257 distinct: false, // We have just one input element; DISTINCT doesn't matter.
3258 }
3259 .on_unique(input_type)
3260 } else {
3261 MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3262 };
3263 (result_expr, ColumnName::from("?window_agg?"))
3264 }
3265
3266 /// Returns whether the expression is COUNT(*) or not. Note that
3267 /// when we define the count builtin in sql::func, we convert
3268 /// COUNT(*) to COUNT(true), making it indistinguishable from
3269 /// literal COUNT(true), but we prefer to consider this as the
3270 /// former.
3271 ///
3272 /// (HIR has the same `is_count_asterisk`.)
3273 pub fn is_count_asterisk(&self) -> bool {
3274 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3275 }
3276}
3277
/// Describe a join implementation in dataflow.
///
/// Note: the derived `Ord`/`PartialOrd` compare variants in declaration order,
/// so do not reorder variants without considering that.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected. This is the `Default`.
    Unimplemented,
}
3325
3326impl Default for JoinImplementation {
3327 fn default() -> Self {
3328 JoinImplementation::Unimplemented
3329 }
3330}
3331
3332impl JoinImplementation {
3333 /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3334 pub fn is_implemented(&self) -> bool {
3335 match self {
3336 Self::Unimplemented => false,
3337 _ => true,
3338 }
3339 }
3340
3341 /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3342 pub fn name(&self) -> Option<&'static str> {
3343 match self {
3344 Self::Differential(..) => Some("differential"),
3345 Self::DeltaQuery(..) => Some("delta"),
3346 Self::IndexedFilter(..) => Some("indexed_filter"),
3347 Self::Unimplemented => None,
3348 }
3349 }
3350}
3351
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on the
/// `enable_join_prioritize_arranged` feature flag.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3370
impl JoinInputCharacteristics {
    /// Creates a new instance with the given characteristics.
    ///
    /// `enable_join_prioritize_arranged` selects between the newer `V2` scoring
    /// (when `true`) and the older `V1` scoring (when `false`).
    pub fn new(
        unique_key: bool,
        key_length: usize,
        arranged: bool,
        cardinality: Option<usize>,
        filters: FilterCharacteristics,
        input: usize,
        enable_join_prioritize_arranged: bool,
    ) -> Self {
        if enable_join_prioritize_arranged {
            Self::V2(JoinInputCharacteristicsV2::new(
                unique_key,
                key_length,
                arranged,
                cardinality,
                filters,
                input,
            ))
        } else {
            Self::V1(JoinInputCharacteristicsV1::new(
                unique_key,
                key_length,
                arranged,
                cardinality,
                filters,
                input,
            ))
        }
    }

    /// Turns the instance into a String to be printed in EXPLAIN.
    pub fn explain(&self) -> String {
        match self {
            Self::V1(jic) => jic.explain(),
            Self::V2(jic) => jic.explain(),
        }
    }

    /// Whether the join input described by `self` is arranged.
    pub fn arranged(&self) -> bool {
        match self {
            Self::V1(jic) => jic.arranged,
            Self::V2(jic) => jic.arranged,
        }
    }

    /// Returns a mutable reference to the `FilterCharacteristics` for the join
    /// input described by `self`.
    pub fn filters(&mut self) -> &mut FilterCharacteristics {
        match self {
            Self::V1(jic) => &mut jic.filters,
            Self::V2(jic) => &mut jic.filters,
        }
    }
}
3427
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// Note: the derived `Ord` compares fields top to bottom, so the field order
/// below encodes the priority of the characteristics; do not reorder fields
/// without considering that.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better, hence the `Reverse`)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3449
3450impl JoinInputCharacteristicsV2 {
3451 /// Creates a new instance with the given characteristics.
3452 pub fn new(
3453 unique_key: bool,
3454 key_length: usize,
3455 arranged: bool,
3456 cardinality: Option<usize>,
3457 filters: FilterCharacteristics,
3458 input: usize,
3459 ) -> Self {
3460 Self {
3461 unique_key,
3462 not_cross: key_length > 0,
3463 arranged,
3464 key_length,
3465 cardinality: cardinality.map(std::cmp::Reverse),
3466 filters,
3467 input: std::cmp::Reverse(input),
3468 }
3469 }
3470
3471 /// Turns the instance into a String to be printed in EXPLAIN.
3472 pub fn explain(&self) -> String {
3473 let mut e = "".to_owned();
3474 if self.unique_key {
3475 e.push_str("U");
3476 }
3477 // Don't need to print `not_cross`, because that is visible in the printed key.
3478 // if !self.not_cross {
3479 // e.push_str("C");
3480 // }
3481 for _ in 0..self.key_length {
3482 e.push_str("K");
3483 }
3484 if self.arranged {
3485 e.push_str("A");
3486 }
3487 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3488 e.push_str(&format!("|{cardinality}|"));
3489 }
3490 e.push_str(&self.filters.explain());
3491 e
3492 }
3493}
3494
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// Note: the derived `Ord` compares fields top to bottom, so the field order
/// below encodes the priority of the characteristics; do not reorder fields
/// without considering that.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better, hence the `Reverse`)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3511
3512impl JoinInputCharacteristicsV1 {
3513 /// Creates a new instance with the given characteristics.
3514 pub fn new(
3515 unique_key: bool,
3516 key_length: usize,
3517 arranged: bool,
3518 cardinality: Option<usize>,
3519 filters: FilterCharacteristics,
3520 input: usize,
3521 ) -> Self {
3522 Self {
3523 unique_key,
3524 key_length,
3525 arranged,
3526 cardinality: cardinality.map(std::cmp::Reverse),
3527 filters,
3528 input: std::cmp::Reverse(input),
3529 }
3530 }
3531
3532 /// Turns the instance into a String to be printed in EXPLAIN.
3533 pub fn explain(&self) -> String {
3534 let mut e = "".to_owned();
3535 if self.unique_key {
3536 e.push_str("U");
3537 }
3538 for _ in 0..self.key_length {
3539 e.push_str("K");
3540 }
3541 if self.arranged {
3542 e.push_str("A");
3543 }
3544 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3545 e.push_str(&format!("|{cardinality}|"));
3546 }
3547 e.push_str(&self.filters.explain());
3548 e
3549 }
3550}
3551
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset).
    pub limit: Option<L>,
    /// Omit as many rows (applied before the limit).
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3574
3575impl<L> RowSetFinishing<L> {
3576 /// Returns a trivial finishing, i.e., that does nothing to the result set.
3577 pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3578 RowSetFinishing {
3579 order_by: Vec::new(),
3580 limit: None,
3581 offset: 0,
3582 project: (0..arity).collect(),
3583 }
3584 }
3585 /// True if the finishing does nothing to any result set.
3586 pub fn is_trivial(&self, arity: usize) -> bool {
3587 self.limit.is_none()
3588 && self.order_by.is_empty()
3589 && self.offset == 0
3590 && self.project.iter().copied().eq(0..arity)
3591 }
3592 /// True if the finishing does not require an ORDER BY.
3593 ///
3594 /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3595 /// explicit ordering we will skip an arbitrary bag of elements and return
3596 /// the first arbitrary elements in the remaining bag. The result semantics
3597 /// are still correct but maybe surprising for some users.
3598 pub fn is_streamable(&self, arity: usize) -> bool {
3599 self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3600 }
3601}
3602
3603impl RowSetFinishing<NonNeg<i64>, usize> {
3604 /// The number of rows needed from before the finishing to evaluate the finishing:
3605 /// offset + limit.
3606 ///
3607 /// If it returns None, then we need all the rows.
3608 pub fn num_rows_needed(&self) -> Option<usize> {
3609 self.limit
3610 .as_ref()
3611 .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3612 }
3613}
3614
impl RowSetFinishing {
    /// Applies finishing actions to a [`RowCollection`], and reports the total
    /// time it took to run.
    ///
    /// Returns a [`SortedRowCollectionIter`] that contains all of the response data, as
    /// well as the size of the response in bytes.
    pub fn finish(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
        duration_histogram: &Histogram,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // Time the whole finishing operation and record it in the histogram,
        // whether or not it succeeds.
        let now = Instant::now();
        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishing::finish`].
    ///
    /// `max_result_size` caps the memory needed to sort the rows;
    /// `max_returned_query_size`, if present, caps the bytes returned to the
    /// client after OFFSET/LIMIT/projection are applied. Exceeding either cap
    /// yields an `Err` with a human-readable message.
    fn finish_inner(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // How much additional memory is required to make a sorted view.
        // (One `usize` index entry per row in the collection.)
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("result exceeds max size of {max_bytes}",));
        }

        // Sort, then apply OFFSET and the projection; LIMIT is applied just
        // below, only when present.
        let sorted_view = rows.sorted_view(&self.order_by);
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.limit {
            let limit = u64::from(limit);
            let limit = usize::cast_from(limit);
            iter = iter.with_limit(limit);
        };

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        if let Some(max) = max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(max);
                return Err(format!("result exceeds max size of {max_bytes}"));
            }
        }

        Ok((iter, response_size))
    }
}
3685
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset).
    pub remaining_limit: Option<usize>,
    /// Omit as many rows.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    /// Kept around (in addition to the remaining budget below) for error messages.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    // NOTE(review): `finish_incremental_inner` reads but never decrements this
    // field, so it is presumably updated by callers between batches — confirm.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3702
impl RowSetFinishingIncremental {
    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
    /// `true`.
    ///
    /// # Panics
    ///
    /// Panics if the result is not streamable, that is it has an ORDER BY.
    pub fn new(
        offset: usize,
        limit: Option<NonNeg<i64>>,
        project: Vec<usize>,
        max_returned_query_size: Option<u64>,
    ) -> Self {
        // Convert the limit from `NonNeg<i64>` (guaranteed non-negative) to a
        // `usize` row count.
        let limit = limit.map(|l| {
            let l = u64::from(l);
            let l = usize::cast_from(l);
            l
        });

        RowSetFinishingIncremental {
            remaining_limit: limit,
            remaining_offset: offset,
            max_returned_query_size,
            remaining_max_returned_query_size: max_returned_query_size,
            project,
        }
    }

    /// Applies finishing actions to the given [`RowCollection`], and reports
    /// the total time it took to run.
    ///
    /// Returns a [`SortedRowCollectionIter`] that contains all of the response
    /// data.
    pub fn finish_incremental(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
        duration_histogram: &Histogram,
    ) -> Result<SortedRowCollectionIter, String> {
        // Time the whole operation and record it in the histogram, whether or
        // not it succeeds.
        let now = Instant::now();
        let result = self.finish_incremental_inner(rows, max_result_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishingIncremental::finish_incremental`]:
    /// applies the remaining offset/limit and the projection to one batch, and
    /// updates the remaining offset/limit for subsequent batches.
    fn finish_incremental_inner(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
    ) -> Result<SortedRowCollectionIter, String> {
        // How much additional memory is required to make a sorted view.
        // (One `usize` index entry per row in the collection.)
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("total result exceeds max size of {max_bytes}",));
        }

        // Number of rows in this batch, needed below to advance the offset.
        let batch_num_rows = rows.count(0, None);

        // No ORDER BY here (streamable-only), hence the empty order spec.
        let sorted_view = rows.sorted_view(&[]);
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.remaining_offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.remaining_limit {
            iter = iter.with_limit(limit);
        };

        // Advance the bookkeeping for the next batch: the offset shrinks by
        // the rows this batch consumed, and the limit shrinks by the rows this
        // batch emits. (The `with_limit` above guarantees `iter.count()` does
        // not exceed `remaining_limit`, so the subtraction cannot underflow.)
        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
            *remaining_limit -= iter.count();
        }

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        // (The error message reports the client's original cap, not the remaining budget.)
        if let Some(max) = self.remaining_max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
                return Err(format!("total result exceeds max size of {max_bytes}"));
            }
        }

        Ok(iter)
    }
}
3803
3804/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3805/// ordering, call `tiebreaker`.
3806pub fn compare_columns<F>(
3807 order: &[ColumnOrder],
3808 left: &[Datum],
3809 right: &[Datum],
3810 tiebreaker: F,
3811) -> Ordering
3812where
3813 F: Fn() -> Ordering,
3814{
3815 for order in order {
3816 let cmp = match (&left[order.column], &right[order.column]) {
3817 (Datum::Null, Datum::Null) => Ordering::Equal,
3818 (Datum::Null, _) => {
3819 if order.nulls_last {
3820 Ordering::Greater
3821 } else {
3822 Ordering::Less
3823 }
3824 }
3825 (_, Datum::Null) => {
3826 if order.nulls_last {
3827 Ordering::Less
3828 } else {
3829 Ordering::Greater
3830 }
3831 }
3832 (lval, rval) => {
3833 if order.desc {
3834 rval.cmp(lval)
3835 } else {
3836 lval.cmp(rval)
3837 }
3838 }
3839 };
3840 if cmp != Ordering::Equal {
3841 return cmp;
3842 }
3843 }
3844 tiebreaker()
3845}
3846
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3861
3862impl Display for WindowFrame {
3863 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3864 write!(
3865 f,
3866 "{} between {} and {}",
3867 self.units, self.start_bound, self.end_bound
3868 )
3869 }
3870}
3871
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined:
    /// `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`.
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Whether the frame, applied to a single-row partition, includes that
    /// (current) row.
    ///
    /// NOTE(review): the `unreachable!()` arms assume frames whose start bound
    /// lies after their end bound have been rejected or normalized before this
    /// point — confirm against the planner.
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            UnboundedPreceding => match self.end_bound {
                UnboundedPreceding => false,
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            CurrentRow => true,
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                // A strictly-following start bound can never cover the current row.
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            UnboundedFollowing => false,
        }
    }
}
3930
/// Describe how frame bounds are interpreted.
///
/// Corresponds to SQL's `ROWS`/`RANGE`/`GROUPS` frame units.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3943
3944impl Display for WindowFrameUnits {
3945 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3946 match self {
3947 WindowFrameUnits::Rows => write!(f, "rows"),
3948 WindowFrameUnits::Range => write!(f, "range"),
3949 WindowFrameUnits::Groups => write!(f, "groups"),
3950 }
3951 }
3952}
3953
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`.
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, MzReflect, PartialOrd, Ord)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}
3971
3972impl Display for WindowFrameBound {
3973 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3974 match self {
3975 WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
3976 WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
3977 WindowFrameBound::CurrentRow => write!(f, "current row"),
3978 WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
3979 WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
3980 }
3981 }
3982}
3983
/// Maximum iterations for a LetRec.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return normally (rather than throw an error) when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
3993
3994impl LetRecLimit {
3995 /// Compute the smallest limit from a Vec of `LetRecLimit`s.
3996 pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
3997 limits
3998 .iter()
3999 .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4000 .min()
4001 }
4002
4003 /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4004 /// WMR without ERROR AT or RETURN AT.
4005 pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4006}
4007
4008impl Display for LetRecLimit {
4009 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4010 write!(f, "[recursion_limit={}", self.max_iters)?;
4011 if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4012 write!(f, ", return_at_limit")?;
4013 }
4014 write!(f, "]")
4015 }
4016}
4017
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    /// Multiple entries mean the same collection is accessed through several indexes.
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4034
#[cfg(test)]
mod tests {
    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// A `RowSetFinishing` should humanize to a one-line `Finish ...` rendering.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let act = {
            let mode = HumanizedExplain::new(false);
            let expr = mode.expr(&finishing, None);
            text_string_at(&expr, mz_ore::str::Indent::default)
        };

        // Expected output, assembled from its logical parts for readability.
        let exp = concat!(
            "Finish",
            " order_by=[#4 desc nulls_last]",
            " limit=7",
            " output=[#1, #3..=#5]",
            "\n",
        );

        assert_eq!(act, exp);
    }
}
4075
/// An iterator over AST structures, which calls out nodes in difference.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
/// to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        fn next(&mut self) -> Option<Self::Item> {
            // Explicit-stack depth-first traversal: pop a pair of nodes and
            // compare their node-local data. On a mismatch, yield the pair;
            // otherwise push the children to be compared on later calls.
            // Child-list lengths are checked before each `zip_eq`, which
            // would otherwise panic on a length mismatch.
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Negate and Threshold carry no node-local data, so
                    // matching variants always defer to their children.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Different variants never match.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            None
        }
    }
}