mz_expr/relation.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cell::RefCell;
13use std::cmp::{Ordering, max};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17use std::hash::{DefaultHasher, Hash, Hasher};
18use std::num::NonZeroU64;
19use std::time::Instant;
20
21use bytesize::ByteSize;
22use columnation::{Columnation, CopyRegion};
23use itertools::Itertools;
24use mz_lowertest::MzReflect;
25use mz_ore::cast::{CastFrom, CastInto};
26use mz_ore::collections::CollectionExt;
27use mz_ore::id_gen::IdGen;
28use mz_ore::metrics::Histogram;
29use mz_ore::num::NonNeg;
30use mz_ore::soft_assert_no_log;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::Indent;
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36 DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39 ColumnName, Datum, DatumVec, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
40 ReprScalarType, Row, RowIterator, RowRef, SqlColumnType, SqlRelationType, SqlScalarType,
41};
42use serde::{Deserialize, Serialize};
43
44use crate::Id::Local;
45use crate::explain::{HumanizedExpr, HumanizerMode};
46use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
47use crate::row::{RowCollection, RowCollectionIter};
48use crate::scalar::func::variadic::{
49 JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
50};
51use crate::visit::{Visit, VisitChildren};
52use crate::{
53 EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
54};
55
56pub mod canonicalize;
57pub mod func;
58pub mod join_input_mapper;
59
/// The recursion limit used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The limit must be large enough to cover the linear (deeply nested) representation of
/// some pathological, yet frequently occurring, query fragments. In MIR these include
/// long chains of
/// - (1) `Let` bindings, and
/// - (2) `CallBinary` calls with associative functions such as `+`.
///
/// Until those cases are fixed, we need to keep this larger limit.
pub const RECURSION_LIMIT: usize = 2_048;
71
72/// A trait for types that describe how to build a collection.
73pub trait CollectionPlan {
74 /// Collects the set of global identifiers from dataflows referenced in Get.
75 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
76
77 /// Returns the set of global identifiers from dataflows referenced in Get.
78 ///
79 /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
80 fn depends_on(&self) -> BTreeSet<GlobalId> {
81 let mut out = BTreeSet::new();
82 self.depends_on_into(&mut out);
83 out
84 }
85}
86
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// The `derived_hash_with_manual_eq` lint is allowed deliberately. The lint exists because it is
/// bad when `Eq` does not agree with `Hash`, which is likely when one of them is implemented
/// manually. Here, however, the manual `PartialEq`/`Eq` implementation agrees with the derived
/// one: it is written by hand not to change the semantics, but to avoid stack overflows.
///
/// NOTE: `Ord`, `PartialOrd` and `Hash` are derived, so they depend on the declaration order of
/// variants and fields; reordering them changes ordering and hashing behavior.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities, or the error the
        /// collection evaluates to.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: ReprRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: ReprRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`, positionally matching `ids`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`, evaluated with each of `ids` bound to the final iterate
        /// of its corresponding entry in `values`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of input columns to retain, in output order. An index may appear
        /// more than once.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or to
        /// expressions defined earlier in the vector.
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row.
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// The table function to apply; its output columns are appended to each input row.
        func: TableFunc,
        /// The arguments to the table function.
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true.
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some expressions and aggregate over each group.
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions, evaluated on each input row, whose values form the group key.
        /// They become the leading output columns.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column orderings used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Maximum number of records to retain per group; `None` means no limit.
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip.
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated.
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds together the multiplicities of rows in `base` and all of `inputs`.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op on the contents of the collection: requests that `input` be
    /// arranged (indexed). Each entry of `keys` describes a separate arrangement to produce.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Keys to arrange `input` by. Each inner vector is one arrangement key, whose
        /// expressions are listed in order of decreasing primacy.
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
315
316impl PartialEq for MirRelationExpr {
317 fn eq(&self, other: &Self) -> bool {
318 // Capture the result and test it wrt `Ord` implementation in test environments.
319 let result = structured_diff::MreDiff::new(self, other).next().is_none();
320 mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
321 result
322 }
323}
324impl Eq for MirRelationExpr {}
325
326impl MirRelationExpr {
327 /// Reports the schema of the relation.
328 ///
329 /// This is the SQL-type parallel of [`Self::typ`]; it is merely
330 /// a wrapper around it, returning a [`SqlRelationType`] instead of
331 /// a [`ReprRelationType`].
332 pub fn sql_typ(&self) -> SqlRelationType {
333 let repr_typ = self.typ();
334 SqlRelationType::from_repr(&repr_typ)
335 }
336
337 /// Reports the repr schema of the relation.
338 ///
339 /// This method determines the type through recursive traversal of the
340 /// relation expression, drawing from the types of base collections.
341 /// As such, this is not an especially cheap method, and should be used
342 /// judiciously.
343 ///
344 /// The relation type is computed incrementally with a recursive post-order
345 /// traversal, that accumulates the input types for the relations yet to be
346 /// visited in `type_stack`.
347 pub fn typ(&self) -> ReprRelationType {
348 let mut type_stack = Vec::new();
349 #[allow(deprecated)]
350 self.visit_pre_post_nolimit(
351 &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
352 match &e {
353 MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
354 MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
355 _ => None,
356 }
357 },
358 &mut |e: &MirRelationExpr| {
359 match e {
360 MirRelationExpr::Let { .. } => {
361 let body_typ = type_stack.pop().unwrap();
362 // Insert a dummy relation type for the value, since `typ_with_input_types`
363 // won't look at it, but expects the relation type of the body to be second.
364 type_stack.push(ReprRelationType::empty());
365 type_stack.push(body_typ);
366 }
367 MirRelationExpr::LetRec { values, .. } => {
368 let body_typ = type_stack.pop().unwrap();
369 type_stack.extend(
370 std::iter::repeat(ReprRelationType::empty()).take(values.len()),
371 );
372 // Insert dummy relation types for the values, since `typ_with_input_types`
373 // won't look at them, but expects the relation type of the body to be last.
374 type_stack.push(body_typ);
375 }
376 _ => {}
377 }
378 let num_inputs = e.num_inputs();
379 let relation_type =
380 e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
381 type_stack.truncate(type_stack.len() - num_inputs);
382 type_stack.push(relation_type);
383 },
384 );
385 assert_eq!(type_stack.len(), 1);
386 type_stack.pop().unwrap()
387 }
388
389 /// Reports the repr schema of the relation given the repr schema of the input relations.
390 pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
391 let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
392 let unique_keys = self.keys_with_input_keys(
393 input_types.iter().map(|i| i.arity()),
394 input_types.iter().map(|i| &i.keys),
395 );
396 ReprRelationType::new(column_types).with_keys(unique_keys)
397 }
398
399 /// Reports the column types of the relation given the column types of the
400 /// input relations.
401 ///
402 /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
403 /// variant is returned.
404 pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
405 where
406 I: Iterator<Item = &'a Vec<ReprColumnType>>,
407 {
408 match self.try_col_with_input_cols(input_types) {
409 Ok(col_types) => col_types,
410 Err(err) => panic!("{err}"),
411 }
412 }
413
414 /// Reports the column types of the relation given the column types of the input relations.
415 ///
416 /// `input_types` is required to contain the column types for the input relations of
417 /// the current relation in the same order as they are visited by `try_visit_children`
418 /// method, even though not all may be used for computing the schema of the
419 /// current relation. For example, `Let` expects two input types, one for the
420 /// value relation and one for the body, in that order, but only the one for the
421 /// body is used to determine the type of the `Let` relation.
422 ///
423 /// It is meant to be used during post-order traversals to compute column types
424 /// incrementally.
425 pub fn try_col_with_input_cols<'a, I>(
426 &self,
427 mut input_types: I,
428 ) -> Result<Vec<ReprColumnType>, String>
429 where
430 I: Iterator<Item = &'a Vec<ReprColumnType>>,
431 {
432 use MirRelationExpr::*;
433
434 let col_types = match self {
435 Constant { rows, typ } => {
436 let mut col_types = typ.column_types.clone();
437 let mut seen_null = vec![false; typ.arity()];
438 if let Ok(rows) = rows {
439 for (row, _diff) in rows {
440 for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
441 if datum.is_null() {
442 seen_null[i] = true;
443 }
444 }
445 }
446 }
447 for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
448 if !seen_null {
449 col_types[i].nullable = false;
450 } else {
451 assert!(col_types[i].nullable);
452 }
453 }
454 col_types
455 }
456 Get { typ, .. } => typ.column_types.clone(),
457 Project { outputs, .. } => {
458 let input = input_types.next().unwrap();
459 outputs.iter().map(|&i| input[i].clone()).collect()
460 }
461 Map { scalars, .. } => {
462 let mut result = input_types.next().unwrap().clone();
463 for scalar in scalars.iter() {
464 result.push(scalar.typ(&result))
465 }
466 result
467 }
468 FlatMap { func, .. } => {
469 let mut result = input_types.next().unwrap().clone();
470 result.extend(
471 func.output_sql_type()
472 .column_types
473 .iter()
474 .map(ReprColumnType::from),
475 );
476 result
477 }
478 Filter { predicates, .. } => {
479 let mut result = input_types.next().unwrap().clone();
480
481 // Set as nonnull any columns where null values would cause
482 // any predicate to evaluate to null.
483 for column in non_nullable_columns(predicates) {
484 result[column].nullable = false;
485 }
486 result
487 }
488 Join { equivalences, .. } => {
489 // Concatenate input column types
490 let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
491 // In an equivalence class, if any column is non-null, then make all non-null
492 for equivalence in equivalences {
493 let col_inds = equivalence
494 .iter()
495 .filter_map(|expr| match expr {
496 MirScalarExpr::Column(col, _name) => Some(*col),
497 _ => None,
498 })
499 .collect_vec();
500 if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
501 for i in col_inds {
502 types.get_mut(i).unwrap().nullable = false;
503 }
504 }
505 }
506 types
507 }
508 Reduce {
509 group_key,
510 aggregates,
511 ..
512 } => {
513 let input = input_types.next().unwrap();
514 group_key
515 .iter()
516 .map(|e| e.typ(input))
517 .chain(aggregates.iter().map(|agg| agg.typ(input)))
518 .collect()
519 }
520 TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
521 input_types.next().unwrap().clone()
522 }
523 Let { .. } => {
524 // skip over the input types for `value`.
525 input_types.nth(1).unwrap().clone()
526 }
527 LetRec { values, .. } => {
528 // skip over the input types for `values`.
529 input_types.nth(values.len()).unwrap().clone()
530 }
531 Union { .. } => {
532 let mut result = input_types.next().unwrap().clone();
533 for input_col_types in input_types {
534 for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
535 *base_col = base_col
536 .union(col)
537 .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
538 }
539 }
540 result
541 }
542 };
543
544 Ok(col_types)
545 }
546
    /// Reports the unique keys of the relation given the arities and the unique
    /// keys of the input relations.
    ///
    /// `input_arities` and `input_keys` are required to contain the
    /// corresponding info for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute unique keys
    /// incrementally.
    ///
    /// Each returned key is a list of column indices of this relation; the returned list
    /// of keys is sorted and deduplicated. An empty key (`vec![]`) is reported when the
    /// relation is known to contain at most one row (see the `Constant` and `Filter` cases).
    ///
    /// # Panics
    ///
    /// Panics if `input_arities` or `input_keys` yield fewer items than the operator needs.
    pub fn keys_with_input_keys<'a, I, J>(
        &self,
        mut input_arities: I,
        mut input_keys: J,
    ) -> Vec<Vec<usize>>
    where
        I: Iterator<Item = usize>,
        J: Iterator<Item = &'a Vec<Vec<usize>>>,
    {
        use MirRelationExpr::*;

        let mut keys = match self {
            Constant {
                rows: Ok(rows),
                typ,
            } => {
                let n_cols = typ.arity();
                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
                for (row, diff) in rows {
                    for (i, datum) in row.iter().enumerate() {
                        // `Datum::Dummy` values are not tracked and never mark a column as non-unique.
                        if datum != Datum::Dummy {
                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
                                // A value is a duplicate if its row's multiplicity is not exactly one,
                                // or if the value was already seen in an earlier row.
                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
                                if is_dupe {
                                    unique_values_per_col[i] = None;
                                }
                            }
                        }
                    }
                }
                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
                    // At most one row: the empty key holds.
                    vec![vec![]]
                } else {
                    // XXX - Multi-column keys are not detected.
                    // Declared keys, plus every single column whose values were all distinct.
                    typ.keys
                        .iter()
                        .cloned()
                        .chain(
                            unique_values_per_col
                                .into_iter()
                                .enumerate()
                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
                                .map(|(idx, _)| vec![idx]),
                        )
                        .collect()
                }
            }
            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
            Let { .. } => {
                // skip over the unique keys for value
                input_keys.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the unique keys for values
                input_keys.nth(values.len()).unwrap().clone()
            }
            Project { outputs, .. } => {
                // Keep only input keys whose columns all survive the projection, renumbered to
                // the (first) output position of each key column.
                let input = input_keys.next().unwrap();
                input
                    .iter()
                    .filter_map(|key_set| {
                        if key_set.iter().all(|k| outputs.contains(k)) {
                            Some(
                                key_set
                                    .iter()
                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
                                    .collect(),
                            )
                        } else {
                            None
                        }
                    })
                    .collect()
            }
            Map { scalars, .. } => {
                // Pairs of (input column, new column) where the new column is computed from the
                // input column by a uniqueness-preserving chain of unary functions.
                let mut remappings = Vec::new();
                let arity = input_arities.next().unwrap();
                for (column, scalar) in scalars.iter().enumerate() {
                    // assess whether the scalar preserves uniqueness,
                    // and could participate in a key!

                    // Returns the column that `expr` is a uniqueness-preserving function of, if any.
                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
                        match expr {
                            MirScalarExpr::CallUnary { func, expr } => {
                                if func.preserves_uniqueness() {
                                    uniqueness(expr)
                                } else {
                                    None
                                }
                            }
                            MirScalarExpr::Column(c, _name) => Some(*c),
                            _ => None,
                        }
                    }

                    if let Some(c) = uniqueness(scalar) {
                        remappings.push((c, column + arity));
                    }
                }

                let mut result = input_keys.next().unwrap().clone();
                let mut new_keys = Vec::new();
                // Any column in `remappings` could be replaced in a key
                // by the corresponding c. This could lead to combinatorial
                // explosion using our current representation, so we wont
                // do that. Instead, we'll handle the case of one remapping.
                if remappings.len() == 1 {
                    let (old, new) = remappings.pop().unwrap();
                    for key in &result {
                        if key.contains(&old) {
                            let mut new_key: Vec<usize> =
                                key.iter().cloned().filter(|k| k != &old).collect();
                            new_key.push(new);
                            new_key.sort_unstable();
                            new_keys.push(new_key);
                        }
                    }
                    result.append(&mut new_keys);
                }
                result
            }
            FlatMap { .. } => {
                // FlatMap can add duplicate rows, so input keys are no longer
                // valid
                vec![]
            }
            Negate { .. } => {
                // Although negate may have distinct records for each key,
                // the multiplicity is -1 rather than 1. This breaks many
                // of the optimization uses of "keys".
                vec![]
            }
            Filter { predicates, .. } => {
                // A filter inherits the keys of its input unless the filters
                // have reduced the input to a single row, in which case the
                // keys of the input are `()`.
                let mut input = input_keys.next().unwrap().clone();

                if !input.is_empty() {
                    // Track columns equated to literals, which we can prune.
                    let mut cols_equal_to_literal = BTreeSet::new();

                    // Perform union find on `col1 = col2` to establish
                    // connected components of equated columns. Absent any
                    // equalities, this will be `0 .. #c` (where #c is the
                    // greatest column referenced by a predicate), but each
                    // equality will orient the root of the greater to the root
                    // of the lesser.
                    let mut union_find = Vec::new();

                    for expr in predicates.iter() {
                        if let MirScalarExpr::CallBinary {
                            func: crate::BinaryFunc::Eq(_),
                            expr1,
                            expr2,
                        } = expr
                        {
                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
                                if expr2.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
                                if expr1.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            // Perform union-find to equate columns.
                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
                                if c1 != c2 {
                                    // Ensure union_find has entries up to
                                    // max(c1, c2) by filling up missing
                                    // positions with identity mappings.
                                    while union_find.len() <= std::cmp::max(c1, c2) {
                                        union_find.push(union_find.len());
                                    }
                                    let mut r1 = c1; // Find the representative column of [c1].
                                    while r1 != union_find[r1] {
                                        assert!(union_find[r1] < r1);
                                        r1 = union_find[r1];
                                    }
                                    let mut r2 = c2; // Find the representative column of [c2].
                                    while r2 != union_find[r2] {
                                        assert!(union_find[r2] < r2);
                                        r2 = union_find[r2];
                                    }
                                    // Union [c1] and [c2] by pointing the
                                    // larger to the smaller representative (we
                                    // update the remaining equivalence class
                                    // members only once after this for-loop).
                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
                                }
                            }
                        }
                    }

                    // Complete union-find by pointing each element at its representative column.
                    for i in 0..union_find.len() {
                        // Iteration not required, as each prior already references the right column.
                        union_find[i] = union_find[union_find[i]];
                    }

                    // Remove columns bound to literals, and remap columns equated to earlier columns.
                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
                    // (Literal-bound columns are removed before remapping, so a column that is only
                    // transitively equal to a literal is kept.)
                    for key_set in &mut input {
                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
                        for col in key_set.iter_mut() {
                            if let Some(equiv) = union_find.get(*col) {
                                *col = *equiv;
                            }
                        }
                        key_set.sort();
                        key_set.dedup();
                    }
                    input.sort();
                    input.dedup();

                    // Expand out each key to each of its equivalent forms.
                    // Each instance of `col` can be replaced by any equivalent column.
                    // This has the potential to result in exponentially sized number of unique keys,
                    // and in the future we should probably maintain unique keys modulo equivalence.

                    // First, compute an inverse map from each representative
                    // column `sub` to all other equivalent columns `col`.
                    let mut subs = Vec::new();
                    for (col, sub) in union_find.iter().enumerate() {
                        if *sub != col {
                            assert!(*sub < col);
                            while subs.len() <= *sub {
                                subs.push(Vec::new());
                            }
                            subs[*sub].push(col);
                        }
                    }
                    // For each column, substitute for it in each occurrence.
                    // Keys added for one column are visible when processing later columns.
                    let mut to_add = Vec::new();
                    for (col, subs) in subs.iter().enumerate() {
                        if !subs.is_empty() {
                            for key_set in input.iter() {
                                if key_set.contains(&col) {
                                    let mut to_extend = key_set.clone();
                                    to_extend.retain(|c| c != &col);
                                    for sub in subs {
                                        to_extend.push(*sub);
                                        to_add.push(to_extend.clone());
                                        to_extend.pop();
                                    }
                                }
                            }
                        }
                        // No deduplication, as we cannot introduce duplicates.
                        input.append(&mut to_add);
                    }
                    for key_set in input.iter_mut() {
                        key_set.sort();
                        key_set.dedup();
                    }
                }
                input
            }
            Join { equivalences, .. } => {
                // It is important the `new_from_input_arities` constructor is
                // used. Otherwise, Materialize may potentially end up in an
                // infinite loop.
                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);

                input_mapper.global_keys(input_keys, equivalences)
            }
            Reduce { group_key, .. } => {
                // The group key should form a key, but we might already have
                // keys that are subsets of the group key, and should retain
                // those instead, if so.
                let mut result = Vec::new();
                for key_set in input_keys.next().unwrap() {
                    // Only plain-column group key expressions can carry an input key over;
                    // the key is renumbered to the position of each column in `group_key`.
                    if key_set
                        .iter()
                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
                    {
                        result.push(
                            key_set
                                .iter()
                                .map(|i| {
                                    group_key
                                        .iter()
                                        .position(|k| k == &MirScalarExpr::column(*i))
                                        .unwrap()
                                })
                                .collect::<Vec<_>>(),
                        );
                    }
                }
                if result.is_empty() {
                    // Fall back to the whole group key (the leading output columns).
                    result.push((0..group_key.len()).collect());
                }
                result
            }
            TopK {
                group_key, limit, ..
            } => {
                // If `limit` is `Some(1)` then the group key will become
                // a unique key, as there will be only one record with that key.
                let mut result = input_keys.next().unwrap().clone();
                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
                    result.push(group_key.clone())
                }
                result
            }
            Union { base, inputs } => {
                // Generally, unions do not have any unique keys, because
                // each input might duplicate some. However, there is at
                // least one idiomatic structure that does preserve keys,
                // which results from SQL aggregations that must populate
                // absent records with default values. In that pattern,
                // the union of one GET with its negation, which has first
                // been subjected to a projection and map, we can remove
                // their influence on the key structure.
                //
                // If there are A, B, each with a unique `key` such that
                // we are looking at
                //
                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
                //
                // Then we can report `key` as a unique key.
                //
                // TODO: make unique key structure an optimization analysis
                // rather than part of the type information.
                // TODO: perhaps ensure that (above) A.proj(key) is a
                // subset of B, as otherwise there are negative records
                // and who knows what is true (not expected, but again
                // who knows what the query plan might look like).

                let arity = input_arities.next().unwrap();
                let (base_projection, base_with_project_stripped) =
                    if let MirRelationExpr::Project { input, outputs } = &**base {
                        (outputs.clone(), &**input)
                    } else {
                        // A input without a project is equivalent to an input
                        // with the project being all columns in the input in order.
                        ((0..arity).collect::<Vec<_>>(), &**base)
                    };
                let mut result = Vec::new();
                // Match the shape `Project?(Get X) + Map(Union(Negate(Project(Get X)), _))`,
                // requiring both `Get`s to read the same id.
                if let MirRelationExpr::Get {
                    id: first_id,
                    typ: _,
                    ..
                } = base_with_project_stripped
                {
                    if inputs.len() == 1 {
                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
                            if let MirRelationExpr::Union { base, inputs } = &**input {
                                if inputs.len() == 1 {
                                    if let Some((input, outputs)) = base.is_negated_project() {
                                        if let MirRelationExpr::Get {
                                            id: second_id,
                                            typ: _,
                                            ..
                                        } = input
                                        {
                                            if first_id == second_id {
                                                // Keep only keys of the base input whose columns
                                                // are left in place by both projections.
                                                result.extend(
                                                    input_keys
                                                        .next()
                                                        .unwrap()
                                                        .into_iter()
                                                        .filter(|key| {
                                                            key.iter().all(|c| {
                                                                outputs.get(*c) == Some(c)
                                                                    && base_projection.get(*c)
                                                                        == Some(c)
                                                            })
                                                        })
                                                        .cloned(),
                                                );
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // Important: do not inherit keys of either input, as not unique.
                result
            }
        };
        keys.sort();
        keys.dedup();
        keys
    }
951
    /// The number of columns in the relation.
    ///
    /// The full type is not derived. Arities are read directly from `Constant`/`Get` types and
    /// from `Project`/`Reduce` nodes, and are combined for every other operator via
    /// [`Self::arity_with_input_arities`].
    ///
    /// The arity is computed incrementally with a post-order traversal. The arities of the
    /// relations visited so far, but not yet consumed by their parent, are kept in
    /// `arity_stack`.
    pub fn arity(&self) -> usize {
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visitor: returning `Some(children)` limits the traversal to those children;
            // `None` leaves the default traversal in place.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graphs, since they are not relevant for
                        // determining the arity of LetRec operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post-visitor: consumes one stack entry per input of `e` and pushes `e`'s arity.
            &mut |e: &MirRelationExpr| {
                match &e {
                    // For operators whose inputs were (partly) skipped above, push placeholder
                    // arities (0) so that the stack holds exactly `e.num_inputs()` entries for
                    // `e`, in child order. `arity_with_input_arities` ignores these
                    // placeholder positions.
                    MirRelationExpr::Let { .. } => {
                        // Slot the placeholder for `value` underneath the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // One placeholder per recursive binding, then the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // Placeholder for the single, untraversed input.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // The top `num_inputs` entries are this node's input arities, in order.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // Only the root's arity should remain.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1009
1010 /// Reports the arity of the relation given the schema of the input relations.
1011 ///
1012 /// `input_arities` is required to contain the arities for the input relations of
1013 /// the current relation in the same order as they are visited by `try_visit_children`
1014 /// method, even though not all may be used for computing the schema of the
1015 /// current relation. For example, `Let` expects two input types, one for the
1016 /// value relation and one for the body, in that order, but only the one for the
1017 /// body is used to determine the type of the `Let` relation.
1018 ///
1019 /// It is meant to be used during post-order traversals to compute arities
1020 /// incrementally.
1021 pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1022 where
1023 I: Iterator<Item = usize>,
1024 {
1025 use MirRelationExpr::*;
1026
1027 match self {
1028 Constant { rows: _, typ } => typ.arity(),
1029 Get { typ, .. } => typ.arity(),
1030 Let { .. } => {
1031 input_arities.next();
1032 input_arities.next().unwrap()
1033 }
1034 LetRec { values, .. } => {
1035 for _ in 0..values.len() {
1036 input_arities.next();
1037 }
1038 input_arities.next().unwrap()
1039 }
1040 Project { outputs, .. } => outputs.len(),
1041 Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1042 FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1043 Join { .. } => input_arities.sum(),
1044 Reduce {
1045 input: _,
1046 group_key,
1047 aggregates,
1048 ..
1049 } => group_key.len() + aggregates.len(),
1050 Filter { .. }
1051 | TopK { .. }
1052 | Negate { .. }
1053 | Threshold { .. }
1054 | Union { .. }
1055 | ArrangeBy { .. } => input_arities.next().unwrap(),
1056 }
1057 }
1058
1059 /// The number of child relations this relation has.
1060 pub fn num_inputs(&self) -> usize {
1061 let mut count = 0;
1062
1063 self.visit_children(|_| count += 1);
1064
1065 count
1066 }
1067
1068 /// Constructs a constant collection from specific rows and schema, where
1069 /// each row will have a multiplicity of one.
1070 pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1071 let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1072 MirRelationExpr::constant_diff(rows, typ)
1073 }
1074
1075 /// Constructs a constant collection from specific rows and schema, where
1076 /// each row can have an arbitrary multiplicity.
1077 pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1078 for (row, _diff) in &rows {
1079 for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1080 assert!(
1081 datum.is_instance_of(column_typ),
1082 "Expected datum of type {:?}, got value {:?}",
1083 column_typ,
1084 datum
1085 );
1086 }
1087 }
1088 let rows = Ok(rows
1089 .into_iter()
1090 .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1091 .collect());
1092 MirRelationExpr::Constant { rows, typ }
1093 }
1094
1095 /// If self is a constant, return the value and the type, otherwise `None`.
1096 /// Looks behind `ArrangeBy`s.
1097 pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1098 match self {
1099 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1100 MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1101 _ => None,
1102 }
1103 }
1104
1105 /// If self is a constant, mutably return the value and the type, otherwise `None`.
1106 /// Looks behind `ArrangeBy`s.
1107 pub fn as_const_mut(
1108 &mut self,
1109 ) -> Option<(
1110 &mut Result<Vec<(Row, Diff)>, EvalError>,
1111 &mut ReprRelationType,
1112 )> {
1113 match self {
1114 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1115 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1116 _ => None,
1117 }
1118 }
1119
1120 /// If self is a constant error, return the error, otherwise `None`.
1121 /// Looks behind `ArrangeBy`s.
1122 pub fn as_const_err(&self) -> Option<&EvalError> {
1123 match self {
1124 MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1125 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1126 _ => None,
1127 }
1128 }
1129
1130 /// Checks if `self` is the single element collection with no columns.
1131 pub fn is_constant_singleton(&self) -> bool {
1132 if let Some((Ok(rows), typ)) = self.as_const() {
1133 rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1134 } else {
1135 false
1136 }
1137 }
1138
1139 /// Constructs the expression for getting a local collection.
1140 pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1141 MirRelationExpr::Get {
1142 id: Id::Local(id),
1143 typ,
1144 access_strategy: AccessStrategy::UnknownOrLocal,
1145 }
1146 }
1147
1148 /// Constructs the expression for getting a global collection
1149 pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1150 MirRelationExpr::Get {
1151 id: Id::Global(id),
1152 typ,
1153 access_strategy: AccessStrategy::UnknownOrLocal,
1154 }
1155 }
1156
1157 /// Retains only the columns specified by `output`.
1158 pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1159 if let MirRelationExpr::Project {
1160 outputs: columns, ..
1161 } = &mut self
1162 {
1163 // Update `outputs` to reference base columns of `input`.
1164 for column in outputs.iter_mut() {
1165 *column = columns[*column];
1166 }
1167 *columns = outputs;
1168 self
1169 } else {
1170 MirRelationExpr::Project {
1171 input: Box::new(self),
1172 outputs,
1173 }
1174 }
1175 }
1176
1177 /// Append to each row the results of applying elements of `scalar`.
1178 pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1179 if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1180 s.extend(scalars);
1181 self
1182 } else if !scalars.is_empty() {
1183 MirRelationExpr::Map {
1184 input: Box::new(self),
1185 scalars,
1186 }
1187 } else {
1188 self
1189 }
1190 }
1191
1192 /// Append to each row a single `scalar`.
1193 pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1194 self.map(vec![scalar])
1195 }
1196
1197 /// Like `map`, but yields zero-or-more output rows per input row
1198 pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1199 MirRelationExpr::FlatMap {
1200 input: Box::new(self),
1201 func,
1202 exprs,
1203 }
1204 }
1205
1206 /// Retain only the rows satisfying each of several predicates.
1207 pub fn filter<I>(mut self, predicates: I) -> Self
1208 where
1209 I: IntoIterator<Item = MirScalarExpr>,
1210 {
1211 // Extract existing predicates
1212 let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1213 self = *input;
1214 predicates
1215 } else {
1216 Vec::new()
1217 };
1218 // Normalize collection of predicates.
1219 new_predicates.extend(predicates);
1220 new_predicates.retain(|p| !p.is_literal_true());
1221 new_predicates.sort();
1222 new_predicates.dedup();
1223 // Introduce a `Filter` only if we have predicates.
1224 if !new_predicates.is_empty() {
1225 self = MirRelationExpr::Filter {
1226 input: Box::new(self),
1227 predicates: new_predicates,
1228 };
1229 }
1230
1231 self
1232 }
1233
1234 /// Form the Cartesian outer-product of rows in both inputs.
1235 pub fn product(mut self, right: Self) -> Self {
1236 if right.is_constant_singleton() {
1237 self
1238 } else if self.is_constant_singleton() {
1239 right
1240 } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1241 inputs.push(right);
1242 self
1243 } else {
1244 MirRelationExpr::join(vec![self, right], vec![])
1245 }
1246 }
1247
1248 /// Performs a relational equijoin among the input collections.
1249 ///
1250 /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1251 /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1252 /// of pairs `(input_index, column_index)` where every value described by that set must be equal.
1253 ///
1254 /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1255 /// input collection and for each row examining its `column`th column.
1256 ///
1257 /// # Example
1258 ///
1259 /// ```rust
1260 /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1261 /// use mz_expr::MirRelationExpr;
1262 ///
1263 /// // A common schema for each input.
1264 /// let schema = ReprRelationType::new(vec![
1265 /// ReprScalarType::Int32.nullable(false),
1266 /// ReprScalarType::Int32.nullable(false),
1267 /// ]);
1268 ///
1269 /// // the specific data are not important here.
1270 /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1271 ///
1272 /// // Three collections that could have been different.
1273 /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1274 /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275 /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1276 ///
1277 /// // Join the three relations looking for triangles, like so.
1278 /// //
1279 /// // Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1280 /// let joined = MirRelationExpr::join(
1281 /// vec![input0, input1, input2],
1282 /// vec![
1283 /// vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1284 /// vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1285 /// vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1286 /// ],
1287 /// );
1288 ///
1289 /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1290 /// // A projection resolves this and produces the correct output.
1291 /// let result = joined.project(vec![0, 1, 3]);
1292 /// ```
1293 pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1294 let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1295
1296 let equivalences = variables
1297 .into_iter()
1298 .map(|vs| {
1299 vs.into_iter()
1300 .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1301 .collect::<Vec<_>>()
1302 })
1303 .collect::<Vec<_>>();
1304
1305 Self::join_scalars(inputs, equivalences)
1306 }
1307
1308 /// Constructs a join operator from inputs and required-equal scalar expressions.
1309 pub fn join_scalars(
1310 mut inputs: Vec<MirRelationExpr>,
1311 equivalences: Vec<Vec<MirScalarExpr>>,
1312 ) -> Self {
1313 // Remove all constant inputs that are the identity for join.
1314 // They neither introduce nor modify any column references.
1315 inputs.retain(|i| !i.is_constant_singleton());
1316 MirRelationExpr::Join {
1317 inputs,
1318 equivalences,
1319 implementation: JoinImplementation::Unimplemented,
1320 }
1321 }
1322
1323 /// Perform a key-wise reduction / aggregation.
1324 ///
1325 /// The `group_key` argument indicates columns in the input collection that should
1326 /// be grouped, and `aggregates` lists aggregation functions each of which produces
1327 /// one output column in addition to the keys.
1328 pub fn reduce(
1329 self,
1330 group_key: Vec<usize>,
1331 aggregates: Vec<AggregateExpr>,
1332 expected_group_size: Option<u64>,
1333 ) -> Self {
1334 MirRelationExpr::Reduce {
1335 input: Box::new(self),
1336 group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1337 aggregates,
1338 monotonic: false,
1339 expected_group_size,
1340 }
1341 }
1342
1343 /// Perform a key-wise reduction order by and limit.
1344 ///
1345 /// The `group_key` argument indicates columns in the input collection that should
1346 /// be grouped, the `order_key` argument indicates columns that should be further
1347 /// used to order records within groups, and the `limit` argument constrains the
1348 /// total number of records that should be produced in each group.
1349 pub fn top_k(
1350 self,
1351 group_key: Vec<usize>,
1352 order_key: Vec<ColumnOrder>,
1353 limit: Option<MirScalarExpr>,
1354 offset: usize,
1355 expected_group_size: Option<u64>,
1356 ) -> Self {
1357 MirRelationExpr::TopK {
1358 input: Box::new(self),
1359 group_key,
1360 order_key,
1361 limit,
1362 offset,
1363 expected_group_size,
1364 monotonic: false,
1365 }
1366 }
1367
1368 /// Negates the occurrences of each row.
1369 pub fn negate(self) -> Self {
1370 if let MirRelationExpr::Negate { input } = self {
1371 *input
1372 } else {
1373 MirRelationExpr::Negate {
1374 input: Box::new(self),
1375 }
1376 }
1377 }
1378
1379 /// Removes all but the first occurrence of each row.
1380 pub fn distinct(self) -> Self {
1381 let arity = self.arity();
1382 self.distinct_by((0..arity).collect())
1383 }
1384
1385 /// Removes all but the first occurrence of each key. Columns not included
1386 /// in the `group_key` are discarded.
1387 pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1388 self.reduce(group_key, vec![], None)
1389 }
1390
1391 /// Discards rows with a negative frequency.
1392 pub fn threshold(self) -> Self {
1393 if let MirRelationExpr::Threshold { .. } = &self {
1394 self
1395 } else {
1396 MirRelationExpr::Threshold {
1397 input: Box::new(self),
1398 }
1399 }
1400 }
1401
1402 /// Unions together any number inputs.
1403 ///
1404 /// If `inputs` is empty, then an empty relation of type `typ` is
1405 /// constructed.
1406 pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1407 // Deconstruct `inputs` as `Union`s and reconstitute.
1408 let mut flat_inputs = Vec::with_capacity(inputs.len());
1409 for input in inputs {
1410 if let MirRelationExpr::Union { base, inputs } = input {
1411 flat_inputs.push(*base);
1412 flat_inputs.extend(inputs);
1413 } else {
1414 flat_inputs.push(input);
1415 }
1416 }
1417 inputs = flat_inputs;
1418 if inputs.len() == 0 {
1419 MirRelationExpr::Constant {
1420 rows: Ok(vec![]),
1421 typ,
1422 }
1423 } else if inputs.len() == 1 {
1424 inputs.into_element()
1425 } else {
1426 MirRelationExpr::Union {
1427 base: Box::new(inputs.remove(0)),
1428 inputs,
1429 }
1430 }
1431 }
1432
1433 /// Produces one collection where each row is present with the sum of its frequencies in each input.
1434 pub fn union(self, other: Self) -> Self {
1435 // Deconstruct `self` and `other` as `Union`s and reconstitute.
1436 let mut flat_inputs = Vec::with_capacity(2);
1437 if let MirRelationExpr::Union { base, inputs } = self {
1438 flat_inputs.push(*base);
1439 flat_inputs.extend(inputs);
1440 } else {
1441 flat_inputs.push(self);
1442 }
1443 if let MirRelationExpr::Union { base, inputs } = other {
1444 flat_inputs.push(*base);
1445 flat_inputs.extend(inputs);
1446 } else {
1447 flat_inputs.push(other);
1448 }
1449
1450 MirRelationExpr::Union {
1451 base: Box::new(flat_inputs.remove(0)),
1452 inputs: flat_inputs,
1453 }
1454 }
1455
1456 /// Arranges the collection by the specified columns
1457 pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1458 MirRelationExpr::ArrangeBy {
1459 input: Box::new(self),
1460 keys: keys.to_owned(),
1461 }
1462 }
1463
1464 /// Indicates if this is a constant empty collection.
1465 ///
1466 /// A false value does not mean the collection is known to be non-empty,
1467 /// only that we cannot currently determine that it is statically empty.
1468 pub fn is_empty(&self) -> bool {
1469 if let Some((Ok(rows), ..)) = self.as_const() {
1470 rows.is_empty()
1471 } else {
1472 false
1473 }
1474 }
1475
1476 /// If the expression is a negated project, return the input and the projection.
1477 pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1478 if let MirRelationExpr::Negate { input } = self {
1479 if let MirRelationExpr::Project { input, outputs } = &**input {
1480 return Some((&**input, outputs));
1481 }
1482 }
1483 if let MirRelationExpr::Project { input, outputs } = self {
1484 if let MirRelationExpr::Negate { input } = &**input {
1485 return Some((&**input, outputs));
1486 }
1487 }
1488 None
1489 }
1490
1491 /// Pretty-print this [MirRelationExpr] to a string.
1492 pub fn pretty(&self) -> String {
1493 let config = ExplainConfig::default();
1494 self.debug_explain(&config, None)
1495 }
1496
1497 /// Pretty-print this [MirRelationExpr] to a string using a custom
1498 /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1499 /// This is intended for debugging and tests, not users.
1500 pub fn debug_explain(
1501 &self,
1502 config: &ExplainConfig,
1503 humanizer: Option<&dyn ExprHumanizer>,
1504 ) -> String {
1505 text_string_at(self, || PlanRenderingContext {
1506 indent: Indent::default(),
1507 humanizer: humanizer.unwrap_or(&DummyHumanizer),
1508 annotations: BTreeMap::default(),
1509 config,
1510 ambiguous_ids: BTreeSet::default(),
1511 })
1512 }
1513
1514 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1515 /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1516 /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1517 /// the best possible key and nullability, since we are making an empty collection.
1518 ///
1519 /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1520 /// the correct type.
1521 pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1522 if let Some(typ) = &typ {
1523 let self_typ = self.typ();
1524 soft_assert_no_log!(
1525 self_typ
1526 .column_types
1527 .iter()
1528 .zip_eq(typ.column_types.iter())
1529 .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1530 );
1531 }
1532 let mut typ = typ.unwrap_or_else(|| self.typ());
1533 typ.keys = vec![vec![]];
1534 for ct in typ.column_types.iter_mut() {
1535 ct.nullable = false;
1536 }
1537 std::mem::replace(
1538 self,
1539 MirRelationExpr::Constant {
1540 rows: Ok(vec![]),
1541 typ,
1542 },
1543 )
1544 }
1545
1546 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1547 /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1548 /// possible nullability, since we are making an empty collection.
1549 pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1550 self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1551 }
1552
1553 /// Like [`Self::take_safely_with_col_types`], but accepts `Vec<ReprColumnType>`.
1554 ///
1555 /// This is the preferred entry point for optimizer transforms, where repr
1556 /// types are the native currency. Internally converts to [`SqlColumnType`]
1557 /// and delegates to [`Self::take_safely_with_col_types`].
1558 pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
1559 self.take_safely(Some(ReprRelationType::new(typ)))
1560 }
1561
1562 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1563 ///
1564 /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1565 pub fn take_dangerous(&mut self) -> MirRelationExpr {
1566 let empty = MirRelationExpr::Constant {
1567 rows: Ok(vec![]),
1568 typ: ReprRelationType::new(Vec::new()),
1569 };
1570 std::mem::replace(self, empty)
1571 }
1572
1573 /// Replaces `self` with some logic applied to `self`.
1574 pub fn replace_using<F>(&mut self, logic: F)
1575 where
1576 F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1577 {
1578 let empty = MirRelationExpr::Constant {
1579 rows: Ok(vec![]),
1580 typ: ReprRelationType::new(Vec::new()),
1581 };
1582 let expr = std::mem::replace(self, empty);
1583 *self = logic(expr);
1584 }
1585
1586 /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1587 pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1588 where
1589 Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1590 {
1591 if let MirRelationExpr::Get { .. } = self {
1592 // already done
1593 body(id_gen, self)
1594 } else {
1595 let id = LocalId::new(id_gen.allocate_id());
1596 let get = MirRelationExpr::Get {
1597 id: Id::Local(id),
1598 typ: self.typ(),
1599 access_strategy: AccessStrategy::UnknownOrLocal,
1600 };
1601 let body = (body)(id_gen, get)?;
1602 Ok(MirRelationExpr::Let {
1603 id,
1604 value: Box::new(self),
1605 body: Box::new(body),
1606 })
1607 }
1608 }
1609
1610 /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1611 /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1612 pub fn anti_lookup<E>(
1613 self,
1614 id_gen: &mut IdGen,
1615 keys_and_values: MirRelationExpr,
1616 default: Vec<(Datum, ReprScalarType)>,
1617 ) -> Result<MirRelationExpr, E> {
1618 let (data, column_types): (Vec<_>, Vec<_>) = default
1619 .into_iter()
1620 .map(|(datum, scalar_type)| {
1621 (
1622 datum,
1623 ReprColumnType {
1624 scalar_type,
1625 nullable: datum.is_null(),
1626 },
1627 )
1628 })
1629 .unzip();
1630 assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1631 self.let_in(id_gen, |_id_gen, get_keys| {
1632 let get_keys_arity = get_keys.arity();
1633 Ok(MirRelationExpr::join(
1634 vec![
1635 // all the missing keys (with count 1)
1636 keys_and_values
1637 .distinct_by((0..get_keys_arity).collect())
1638 .negate()
1639 .union(get_keys.clone().distinct()),
1640 // join with keys to get the correct counts
1641 get_keys.clone(),
1642 ],
1643 (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1644 )
1645 // get rid of the extra copies of columns from keys
1646 .project((0..get_keys_arity).collect())
1647 // This join is logically equivalent to
1648 // `.map(<default_expr>)`, but using a join allows for
1649 // potential predicate pushdown and elision in the
1650 // optimizer.
1651 .product(MirRelationExpr::constant(
1652 vec![data],
1653 ReprRelationType::new(column_types),
1654 )))
1655 })
1656 }
1657
1658 /// Return:
1659 /// * every row in keys_and_values
1660 /// * every row in `self` that does not have a matching row in the first columns of
1661 /// `keys_and_values`, using `default` to fill in the remaining columns
1662 /// (This is LEFT OUTER JOIN if:
1663 /// 1) `default` is a row of null
1664 /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1665 pub fn lookup<E>(
1666 self,
1667 id_gen: &mut IdGen,
1668 keys_and_values: MirRelationExpr,
1669 default: Vec<(Datum<'static>, ReprScalarType)>,
1670 ) -> Result<MirRelationExpr, E> {
1671 keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1672 Ok(get_keys_and_values.clone().union(self.anti_lookup(
1673 id_gen,
1674 get_keys_and_values,
1675 default,
1676 )?))
1677 })
1678 }
1679
1680 /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1681 pub fn contains_temporal(&self) -> bool {
1682 let mut contains = false;
1683 self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1684 contains
1685 }
1686
    /// Fallible mutable visitor for the [`MirScalarExpr`]s directly owned by this relation
    /// expression (not those of its child relations).
    ///
    /// Visits, per operator: `Map` scalars, `Filter` predicates, `FlatMap` argument
    /// expressions, `Join` equivalences plus any key expressions stored in its
    /// `implementation`, `ArrangeBy` keys, `Reduce` group keys and aggregate inputs, and a
    /// `TopK` limit. All other operators own no scalars.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    /// The first error returned by `f` stops the visit and is propagated.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // Planned join implementations carry their own key expressions.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1786
1787 /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1788 /// rooted at `self`.
1789 ///
1790 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1791 /// nodes.
1792 pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1793 where
1794 F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1795 E: From<RecursionLimitError>,
1796 {
1797 self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1798 }
1799
1800 /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1801 /// rooted at `self`.
1802 ///
1803 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1804 /// nodes.
1805 pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1806 where
1807 F: FnMut(&mut MirScalarExpr),
1808 {
1809 self.try_visit_scalars_mut(&mut |s| {
1810 f(s);
1811 Ok::<_, RecursionLimitError>(())
1812 })
1813 .expect("Unexpected error in `visit_scalars_mut` call");
1814 }
1815
    /// Fallible immutable visitor for the [`MirScalarExpr`]s directly owned by this relation
    /// expression (not those of its child relations).
    ///
    /// Visits, per operator: `Map` scalars, `Filter` predicates, `FlatMap` argument
    /// expressions, `Join` equivalences plus any key expressions stored in its
    /// `implementation`, `ArrangeBy` keys, `Reduce` group keys and aggregate inputs, and a
    /// `TopK` limit. All other operators own no scalars. This mirrors
    /// `try_visit_scalars_mut1`, and the two should be kept in sync.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    /// The first error returned by `f` stops the visit and is propagated.
    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // Planned join implementations carry their own key expressions.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1915
1916 /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1917 /// rooted at `self`.
1918 ///
1919 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1920 /// nodes.
1921 pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1922 where
1923 F: FnMut(&MirScalarExpr) -> Result<(), E>,
1924 E: From<RecursionLimitError>,
1925 {
1926 self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1927 }
1928
1929 /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1930 /// rooted at `self`.
1931 ///
1932 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1933 /// nodes.
1934 pub fn visit_scalars<F>(&self, f: &mut F)
1935 where
1936 F: FnMut(&MirScalarExpr),
1937 {
1938 self.try_visit_scalars(&mut |s| {
1939 f(s);
1940 Ok::<_, RecursionLimitError>(())
1941 })
1942 .expect("Unexpected error in `visit_scalars` call");
1943 }
1944
1945 /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1946 /// stack overflow in `drop_in_place`.
1947 ///
1948 /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1949 /// dropped or otherwise overwritten.
1950 pub fn destroy_carefully(&mut self) {
1951 let mut todo = vec![self.take_dangerous()];
1952 while let Some(mut expr) = todo.pop() {
1953 for child in expr.children_mut() {
1954 todo.push(child.take_dangerous());
1955 }
1956 }
1957 }
1958
1959 /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1960 /// debug printing purposes.
1961 pub fn debug_size_and_depth(&self) -> (usize, usize) {
1962 let mut size = 0;
1963 let mut max_depth = 0;
1964 let mut todo = vec![(self, 1)];
1965 while let Some((expr, depth)) = todo.pop() {
1966 size += 1;
1967 max_depth = max(max_depth, depth);
1968 todo.extend(expr.children().map(|c| (c, depth + 1)));
1969 }
1970 (size, max_depth)
1971 }
1972
1973 /// The MirRelationExpr is considered potentially expensive if and only if
1974 /// at least one of the following conditions is true:
1975 ///
1976 /// - It contains at least one MirScalarExpr with a function call.
1977 /// - It contains at least one FlatMap or a Reduce operator.
1978 /// - We run into a RecursionLimitError while analyzing the expression.
1979 ///
1980 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1981 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1982 pub fn could_run_expensive_function(&self) -> bool {
1983 let mut result = false;
1984 use MirRelationExpr::*;
1985 use MirScalarExpr::*;
1986 if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1987 result |= match scalar {
1988 Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1989 // Function calls are considered expensive
1990 CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1991 };
1992 Ok(())
1993 }) {
1994 // Conservatively set `true` if on RecursionLimitError.
1995 result = true;
1996 }
1997 self.visit_pre(|e: &MirRelationExpr| {
1998 // FlatMap has a table function; Reduce has an aggregate function.
1999 // Other constructs use MirScalarExpr to run a function
2000 result |= matches!(e, FlatMap { .. } | Reduce { .. });
2001 });
2002 result
2003 }
2004
2005 /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2006 /// than what `Hashable::hashed` would give us.)
2007 pub fn hash_to_u64(&self) -> u64 {
2008 let mut h = DefaultHasher::new();
2009 self.hash(&mut h);
2010 h.finish()
2011 }
2012}
2013
2014// `LetRec` helpers
2015impl MirRelationExpr {
2016 /// True when `expr` contains a `LetRec` AST node.
2017 pub fn is_recursive(self: &MirRelationExpr) -> bool {
2018 let mut worklist = vec![self];
2019 while let Some(expr) = worklist.pop() {
2020 if let MirRelationExpr::LetRec { .. } = expr {
2021 return true;
2022 }
2023 worklist.extend(expr.children());
2024 }
2025 false
2026 }
2027
2028 /// Return the number of sub-expressions in the tree (including self).
2029 pub fn size(&self) -> usize {
2030 let mut size = 0;
2031 self.visit_pre(|_| size += 1);
2032 size
2033 }
2034
2035 /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2036 /// iterations. These are those ids that have a reference before they are defined, when reading
2037 /// all the bindings in order.
2038 ///
2039 /// For example:
2040 /// ```SQL
2041 /// WITH MUTUALLY RECURSIVE
2042 /// x(...) AS f(z),
2043 /// y(...) AS g(x),
2044 /// z(...) AS h(y)
2045 /// ...;
2046 /// ```
2047 /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2048 /// iteration.
2049 ///
2050 /// Note that if a binding references itself, that is also returned.
2051 pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2052 let mut used_across_iterations = BTreeSet::new();
2053 let mut defined = BTreeSet::new();
2054 for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2055 value.visit_pre(|expr| {
2056 if let MirRelationExpr::Get {
2057 id: Local(get_id), ..
2058 } = expr
2059 {
2060 // If we haven't seen a definition for it yet, then this will refer
2061 // to the previous iteration.
2062 // The `ids.contains` part of the condition is needed to exclude
2063 // those ids that are not really in this LetRec, but either an inner
2064 // or outer one.
2065 if !defined.contains(get_id) && ids.contains(get_id) {
2066 used_across_iterations.insert(*get_id);
2067 }
2068 }
2069 });
2070 defined.insert(*binding_id);
2071 }
2072 used_across_iterations
2073 }
2074
2075 /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2076 ///
2077 /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2078 /// identifiers are rewritten to be the constant collection.
2079 /// This makes the computation perform exactly "one" iteration.
2080 ///
2081 /// This was used only temporarily while developing `LetRec`.
2082 pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2083 let mut deadlist = BTreeSet::new();
2084 let mut worklist = vec![self];
2085 while let Some(expr) = worklist.pop() {
2086 if let MirRelationExpr::LetRec {
2087 ids,
2088 values,
2089 limits: _,
2090 body,
2091 } = expr
2092 {
2093 let ids_values = values
2094 .drain(..)
2095 .zip_eq(ids)
2096 .map(|(value, id)| (*id, value))
2097 .collect::<Vec<_>>();
2098 *expr = body.take_dangerous();
2099 for (id, mut value) in ids_values.into_iter().rev() {
2100 // Remove references to potentially recursive identifiers.
2101 deadlist.insert(id);
2102 value.visit_pre_mut(|e| {
2103 if let MirRelationExpr::Get {
2104 id: crate::Id::Local(id),
2105 typ,
2106 ..
2107 } = e
2108 {
2109 let typ = typ.clone();
2110 if deadlist.contains(id) {
2111 e.take_safely(Some(typ));
2112 }
2113 }
2114 });
2115 *expr = MirRelationExpr::Let {
2116 id,
2117 value: Box::new(value),
2118 body: Box::new(expr.take_dangerous()),
2119 };
2120 }
2121 worklist.push(expr);
2122 } else {
2123 worklist.extend(expr.children_mut().rev());
2124 }
2125 }
2126 }
2127
2128 /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2129 /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2130 /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2131 /// redefinition.
2132 ///
2133 /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2134 pub fn collect_expirations(
2135 id: LocalId,
2136 expr: &MirRelationExpr,
2137 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2138 ) {
2139 expr.visit_pre(|e| {
2140 if let MirRelationExpr::Get {
2141 id: Id::Local(referenced_id),
2142 ..
2143 } = e
2144 {
2145 // The following check needs `renumber_bindings` to have run recently
2146 if referenced_id >= &id {
2147 expire_whens
2148 .entry(*referenced_id)
2149 .or_insert_with(Vec::new)
2150 .push(id);
2151 }
2152 }
2153 });
2154 }
2155
2156 /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2157 /// about such Ids whose information depended on the earlier definition of `id`, according to
2158 /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2159 pub fn do_expirations<I>(
2160 redefined_id: LocalId,
2161 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2162 id_infos: &mut BTreeMap<LocalId, I>,
2163 ) -> Vec<(LocalId, I)> {
2164 let mut expired_infos = Vec::new();
2165 if let Some(expirations) = expire_whens.remove(&redefined_id) {
2166 for expired_id in expirations.into_iter() {
2167 if let Some(offer) = id_infos.remove(&expired_id) {
2168 expired_infos.push((expired_id, offer));
2169 }
2170 }
2171 }
2172 expired_infos
2173 }
2174}
2175/// Augment non-nullability of columns, by observing either
2176/// 1. Predicates that explicitly test for null values, and
2177/// 2. Columns that if null would make a predicate be null.
2178pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2179 let mut nonnull_required_columns = BTreeSet::new();
2180 for predicate in predicates {
2181 // Add any columns that being null would force the predicate to be null.
2182 // Should that happen, the row would be discarded.
2183 predicate.non_null_requirements(&mut nonnull_required_columns);
2184
2185 /*
2186 Test for explicit checks that a column is non-null.
2187
2188 This analysis is ad hoc, and will miss things:
2189
2190 materialize=> create table a(x int, y int);
2191 CREATE TABLE
2192 materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2193 Optimized Plan
2194 --------------------------------------------------------------------------------------------------------
2195 Explained Query: +
2196 Project (#0) // { types: "(integer?)" } +
2197 Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2198 Get materialize.public.a // { types: "(integer?, integer?)" } +
2199 +
2200 Source materialize.public.a +
2201 filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1)))) +
2202
2203 (1 row)
2204 */
2205
2206 if let MirScalarExpr::CallUnary {
2207 func: UnaryFunc::Not(scalar_func::Not),
2208 expr,
2209 } = predicate
2210 {
2211 if let MirScalarExpr::CallUnary {
2212 func: UnaryFunc::IsNull(scalar_func::IsNull),
2213 expr,
2214 } = &**expr
2215 {
2216 if let MirScalarExpr::Column(c, _name) = &**expr {
2217 nonnull_required_columns.insert(*c);
2218 }
2219 }
2220 }
2221 }
2222
2223 nonnull_required_columns
2224}
2225
2226impl CollectionPlan for MirRelationExpr {
2227 /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2228 /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2229 ///
2230 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2231 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2232 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2233 if let MirRelationExpr::Get {
2234 id: Id::Global(id), ..
2235 } = self
2236 {
2237 out.insert(*id);
2238 }
2239 self.visit_children(|expr| expr.depends_on_into(out))
2240 }
2241}
2242
2243impl MirRelationExpr {
2244 /// Iterates through references to child expressions.
2245 pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2246 let mut first = None;
2247 let mut second = None;
2248 let mut rest = None;
2249 let mut last = None;
2250
2251 use MirRelationExpr::*;
2252 match self {
2253 Constant { .. } | Get { .. } => (),
2254 Let { value, body, .. } => {
2255 first = Some(&**value);
2256 second = Some(&**body);
2257 }
2258 LetRec { values, body, .. } => {
2259 rest = Some(values);
2260 last = Some(&**body);
2261 }
2262 Project { input, .. }
2263 | Map { input, .. }
2264 | FlatMap { input, .. }
2265 | Filter { input, .. }
2266 | Reduce { input, .. }
2267 | TopK { input, .. }
2268 | Negate { input }
2269 | Threshold { input }
2270 | ArrangeBy { input, .. } => {
2271 first = Some(&**input);
2272 }
2273 Join { inputs, .. } => {
2274 rest = Some(inputs);
2275 }
2276 Union { base, inputs } => {
2277 first = Some(&**base);
2278 rest = Some(inputs);
2279 }
2280 }
2281
2282 first
2283 .into_iter()
2284 .chain(second)
2285 .chain(rest.into_iter().flatten())
2286 .chain(last)
2287 }
2288
2289 /// Iterates through mutable references to child expressions.
2290 pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2291 let mut first = None;
2292 let mut second = None;
2293 let mut rest = None;
2294 let mut last = None;
2295
2296 use MirRelationExpr::*;
2297 match self {
2298 Constant { .. } | Get { .. } => (),
2299 Let { value, body, .. } => {
2300 first = Some(&mut **value);
2301 second = Some(&mut **body);
2302 }
2303 LetRec { values, body, .. } => {
2304 rest = Some(values);
2305 last = Some(&mut **body);
2306 }
2307 Project { input, .. }
2308 | Map { input, .. }
2309 | FlatMap { input, .. }
2310 | Filter { input, .. }
2311 | Reduce { input, .. }
2312 | TopK { input, .. }
2313 | Negate { input }
2314 | Threshold { input }
2315 | ArrangeBy { input, .. } => {
2316 first = Some(&mut **input);
2317 }
2318 Join { inputs, .. } => {
2319 rest = Some(inputs);
2320 }
2321 Union { base, inputs } => {
2322 first = Some(&mut **base);
2323 rest = Some(inputs);
2324 }
2325 }
2326
2327 first
2328 .into_iter()
2329 .chain(second)
2330 .chain(rest.into_iter().flatten())
2331 .chain(last)
2332 }
2333
2334 /// Iterative pre-order visitor.
2335 pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2336 let mut worklist = vec![self];
2337 while let Some(expr) = worklist.pop() {
2338 f(expr);
2339 worklist.extend(expr.children().rev());
2340 }
2341 }
2342
2343 /// Iterative pre-order visitor.
2344 pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2345 let mut worklist = vec![self];
2346 while let Some(expr) = worklist.pop() {
2347 f(expr);
2348 worklist.extend(expr.children_mut().rev());
2349 }
2350 }
2351
2352 /// Return a vector of references to the subtrees of this expression
2353 /// in post-visit order (the last element is `&self`).
2354 pub fn post_order_vec(&self) -> Vec<&Self> {
2355 let mut stack = vec![self];
2356 let mut result = vec![];
2357 while let Some(expr) = stack.pop() {
2358 result.push(expr);
2359 stack.extend(expr.children());
2360 }
2361 result.reverse();
2362 result
2363 }
2364}
2365
2366impl VisitChildren<Self> for MirRelationExpr {
2367 fn visit_children<F>(&self, mut f: F)
2368 where
2369 F: FnMut(&Self),
2370 {
2371 for child in self.children() {
2372 f(child)
2373 }
2374 }
2375
2376 fn visit_mut_children<F>(&mut self, mut f: F)
2377 where
2378 F: FnMut(&mut Self),
2379 {
2380 for child in self.children_mut() {
2381 f(child)
2382 }
2383 }
2384
2385 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2386 where
2387 F: FnMut(&Self) -> Result<(), E>,
2388 E: From<RecursionLimitError>,
2389 {
2390 for child in self.children() {
2391 f(child)?
2392 }
2393 Ok(())
2394 }
2395
2396 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2397 where
2398 F: FnMut(&mut Self) -> Result<(), E>,
2399 E: From<RecursionLimitError>,
2400 {
2401 for child in self.children_mut() {
2402 f(child)?
2403 }
2404 Ok(())
2405 }
2406}
2407
/// Specification for an ordering by a column.
///
/// The derived `Ord`/`PartialOrd` compare fields lexicographically in declaration order
/// (`column`, then `desc`, then `nulls_last`), so the field order is significant.
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    // `serde(default)`: deserializes as `false` (ascending) when the field is absent.
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    // `serde(default)`: deserializes as `false` (nulls first) when the field is absent.
    #[serde(default)]
    pub nulls_last: bool,
}
2432
2433impl Columnation for ColumnOrder {
2434 type InnerRegion = CopyRegion<Self>;
2435}
2436
2437impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2438where
2439 M: HumanizerMode,
2440{
2441 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2442 // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2443 write!(
2444 f,
2445 "{} {} {}",
2446 self.child(&self.expr.column),
2447 if self.expr.desc { "desc" } else { "asc" },
2448 if self.expr.nulls_last {
2449 "nulls_last"
2450 } else {
2451 "nulls_first"
2452 },
2453 )
2454 }
2455}
2456
/// Describes an aggregation expression: an aggregate function applied to a per-row scalar
/// expression, optionally restricted to distinct inputs.
///
/// The derived `Ord`/`PartialOrd` compare fields lexicographically in declaration order
/// (`func`, then `expr`, then `distinct`), so the field order is significant.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    // `serde(default)`: deserializes as `false` when the field is absent.
    #[serde(default)]
    pub distinct: bool,
}
2479
2480impl AggregateExpr {
2481 /// Computes the type of this `AggregateExpr`.
2482 pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2483 self.func.output_sql_type(self.expr.sql_typ(column_types))
2484 }
2485
2486 /// Computes the type of this `AggregateExpr`.
2487 pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2488 self.func.output_type(self.expr.typ(column_types))
2489 }
2490
2491 /// Returns whether the expression has a constant result.
2492 pub fn is_constant(&self) -> bool {
2493 match self.func {
2494 AggregateFunc::MaxNumeric
2495 | AggregateFunc::MaxInt16
2496 | AggregateFunc::MaxInt32
2497 | AggregateFunc::MaxInt64
2498 | AggregateFunc::MaxUInt16
2499 | AggregateFunc::MaxUInt32
2500 | AggregateFunc::MaxUInt64
2501 | AggregateFunc::MaxMzTimestamp
2502 | AggregateFunc::MaxFloat32
2503 | AggregateFunc::MaxFloat64
2504 | AggregateFunc::MaxBool
2505 | AggregateFunc::MaxString
2506 | AggregateFunc::MaxDate
2507 | AggregateFunc::MaxTimestamp
2508 | AggregateFunc::MaxTimestampTz
2509 | AggregateFunc::MaxInterval
2510 | AggregateFunc::MaxTime
2511 | AggregateFunc::MinNumeric
2512 | AggregateFunc::MinInt16
2513 | AggregateFunc::MinInt32
2514 | AggregateFunc::MinInt64
2515 | AggregateFunc::MinUInt16
2516 | AggregateFunc::MinUInt32
2517 | AggregateFunc::MinUInt64
2518 | AggregateFunc::MinMzTimestamp
2519 | AggregateFunc::MinFloat32
2520 | AggregateFunc::MinFloat64
2521 | AggregateFunc::MinBool
2522 | AggregateFunc::MinString
2523 | AggregateFunc::MinDate
2524 | AggregateFunc::MinTimestamp
2525 | AggregateFunc::MinTimestampTz
2526 | AggregateFunc::MinInterval
2527 | AggregateFunc::MinTime
2528 | AggregateFunc::Any
2529 | AggregateFunc::All
2530 | AggregateFunc::Dummy => self.expr.is_literal(),
2531 AggregateFunc::Count => self.expr.is_literal_null(),
2532 AggregateFunc::SumInt16
2533 | AggregateFunc::SumInt32
2534 | AggregateFunc::SumInt64
2535 | AggregateFunc::SumUInt16
2536 | AggregateFunc::SumUInt32
2537 | AggregateFunc::SumUInt64
2538 | AggregateFunc::SumFloat32
2539 | AggregateFunc::SumFloat64
2540 | AggregateFunc::SumNumeric
2541 | AggregateFunc::JsonbAgg { .. }
2542 | AggregateFunc::JsonbObjectAgg { .. }
2543 | AggregateFunc::MapAgg { .. }
2544 | AggregateFunc::ArrayConcat { .. }
2545 | AggregateFunc::ListConcat { .. }
2546 | AggregateFunc::StringAgg { .. }
2547 | AggregateFunc::RowNumber { .. }
2548 | AggregateFunc::Rank { .. }
2549 | AggregateFunc::DenseRank { .. }
2550 | AggregateFunc::LagLead { .. }
2551 | AggregateFunc::FirstValue { .. }
2552 | AggregateFunc::LastValue { .. }
2553 | AggregateFunc::FusedValueWindowFunc { .. }
2554 | AggregateFunc::WindowAggregate { .. }
2555 | AggregateFunc::FusedWindowAggregate { .. } => self.expr.is_literal_err(),
2556 }
2557 }
2558
2559 /// Returns an expression that computes `self` on a group that has exactly one row.
2560 /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2561 /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2562 pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
2563 match &self.func {
2564 // Count is one if non-null, and zero if null.
2565 AggregateFunc::Count => self
2566 .expr
2567 .clone()
2568 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2569 .if_then_else(
2570 MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
2571 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2572 ),
2573
2574 // SumInt16 takes Int16s as input, but outputs Int64s.
2575 AggregateFunc::SumInt16 => self
2576 .expr
2577 .clone()
2578 .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2579
2580 // SumInt32 takes Int32s as input, but outputs Int64s.
2581 AggregateFunc::SumInt32 => self
2582 .expr
2583 .clone()
2584 .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2585
2586 // SumInt64 takes Int64s as input, but outputs numerics.
2587 AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2588 scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2589 )),
2590
2591 // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2592 AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2593 UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2594 ),
2595
2596 // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2597 AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2598 UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2599 ),
2600
2601 // SumUInt64 takes UInt64s as input, but outputs numerics.
2602 AggregateFunc::SumUInt64 => {
2603 self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2604 scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2605 ))
2606 }
2607
2608 // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2609 AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
2610 JsonbBuildArray,
2611 vec![
2612 self.expr
2613 .clone()
2614 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2615 ],
2616 ),
2617
2618 // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2619 AggregateFunc::JsonbObjectAgg { .. } => {
2620 let record = self
2621 .expr
2622 .clone()
2623 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2624 MirScalarExpr::call_variadic(
2625 JsonbBuildObject,
2626 (0..2)
2627 .map(|i| {
2628 record
2629 .clone()
2630 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2631 })
2632 .collect(),
2633 )
2634 }
2635
2636 AggregateFunc::MapAgg { value_type, .. } => {
2637 let record = self
2638 .expr
2639 .clone()
2640 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2641 MirScalarExpr::call_variadic(
2642 MapBuild {
2643 value_type: value_type.clone(),
2644 },
2645 (0..2)
2646 .map(|i| {
2647 record
2648 .clone()
2649 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2650 })
2651 .collect(),
2652 )
2653 }
2654
2655 // StringAgg takes nested records of strings and outputs a string
2656 AggregateFunc::StringAgg { .. } => self
2657 .expr
2658 .clone()
2659 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2660 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2661
2662 // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2663 AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2664 .expr
2665 .clone()
2666 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2667
2668 // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2669 AggregateFunc::RowNumber { .. } => {
2670 self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2671 }
2672 AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2673 AggregateFunc::DenseRank { .. } => {
2674 self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2675 }
2676
2677 // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2678 AggregateFunc::LagLead { lag_lead, .. } => {
2679 let tuple = self
2680 .expr
2681 .clone()
2682 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2683
2684 // Get the overall return type
2685 let return_type_with_orig_row = self
2686 .typ(input_type)
2687 .scalar_type
2688 .unwrap_list_element_type()
2689 .clone();
2690 let lag_lead_return_type =
2691 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2692
2693 // Extract the original row
2694 let original_row = tuple
2695 .clone()
2696 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2697
2698 // Extract the encoded args
2699 let encoded_args =
2700 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2701
2702 let (result_expr, column_name) =
2703 Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2704
2705 MirScalarExpr::call_variadic(
2706 ListCreate {
2707 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2708 },
2709 vec![MirScalarExpr::call_variadic(
2710 RecordCreate {
2711 field_names: vec![column_name, ColumnName::from("?record?")],
2712 },
2713 vec![result_expr, original_row],
2714 )],
2715 )
2716 }
2717
2718 // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2719 AggregateFunc::FirstValue { window_frame, .. } => {
2720 let tuple = self
2721 .expr
2722 .clone()
2723 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2724
2725 // Get the overall return type
2726 let return_type_with_orig_row = self
2727 .typ(input_type)
2728 .scalar_type
2729 .unwrap_list_element_type()
2730 .clone();
2731 let first_value_return_type =
2732 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2733
2734 // Extract the original row
2735 let original_row = tuple
2736 .clone()
2737 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2738
2739 // Extract the input value
2740 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2741
2742 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2743 window_frame,
2744 arg,
2745 first_value_return_type,
2746 );
2747
2748 MirScalarExpr::call_variadic(
2749 ListCreate {
2750 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2751 },
2752 vec![MirScalarExpr::call_variadic(
2753 RecordCreate {
2754 field_names: vec![column_name, ColumnName::from("?record?")],
2755 },
2756 vec![result_expr, original_row],
2757 )],
2758 )
2759 }
2760
2761 // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2762 AggregateFunc::LastValue { window_frame, .. } => {
2763 let tuple = self
2764 .expr
2765 .clone()
2766 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2767
2768 // Get the overall return type
2769 let return_type_with_orig_row = self
2770 .typ(input_type)
2771 .scalar_type
2772 .unwrap_list_element_type()
2773 .clone();
2774 let last_value_return_type =
2775 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2776
2777 // Extract the original row
2778 let original_row = tuple
2779 .clone()
2780 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2781
2782 // Extract the input value
2783 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2784
2785 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2786 window_frame,
2787 arg,
2788 last_value_return_type,
2789 );
2790
2791 MirScalarExpr::call_variadic(
2792 ListCreate {
2793 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2794 },
2795 vec![MirScalarExpr::call_variadic(
2796 RecordCreate {
2797 field_names: vec![column_name, ColumnName::from("?record?")],
2798 },
2799 vec![result_expr, original_row],
2800 )],
2801 )
2802 }
2803
2804 // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2805 // See an example MIR in `window_func_applied_to`.
2806 AggregateFunc::WindowAggregate {
2807 wrapped_aggregate,
2808 window_frame,
2809 order_by: _,
2810 } => {
2811 // TODO: deduplicate code between the various window function cases.
2812
2813 let tuple = self
2814 .expr
2815 .clone()
2816 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2817
2818 // Get the overall return type
2819 let return_type = self
2820 .typ(input_type)
2821 .scalar_type
2822 .unwrap_list_element_type()
2823 .clone();
2824 let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2825
2826 // Extract the original row
2827 let original_row = tuple
2828 .clone()
2829 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2830
2831 // Extract the input value
2832 let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2833
2834 let (result, column_name) = Self::on_unique_window_agg(
2835 window_frame,
2836 arg_expr,
2837 input_type,
2838 window_agg_return_type,
2839 wrapped_aggregate,
2840 );
2841
2842 MirScalarExpr::call_variadic(
2843 ListCreate {
2844 elem_type: SqlScalarType::from_repr(&return_type),
2845 },
2846 vec![MirScalarExpr::call_variadic(
2847 RecordCreate {
2848 field_names: vec![column_name, ColumnName::from("?record?")],
2849 },
2850 vec![result, original_row],
2851 )],
2852 )
2853 }
2854
2855 // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2856 AggregateFunc::FusedWindowAggregate {
2857 wrapped_aggregates,
2858 order_by: _,
2859 window_frame,
2860 } => {
2861 // Throw away OrderByExprs
2862 let tuple = self
2863 .expr
2864 .clone()
2865 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2866
2867 // Extract the original row
2868 let original_row = tuple
2869 .clone()
2870 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2871
2872 // Extract the args of the fused call
2873 let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2874
2875 let return_type_with_orig_row = self
2876 .typ(input_type)
2877 .scalar_type
2878 .unwrap_list_element_type()
2879 .clone();
2880
2881 let all_func_return_types =
2882 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2883 let mut func_result_exprs = Vec::new();
2884 let mut col_names = Vec::new();
2885 for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2886 let arg = all_args
2887 .clone()
2888 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2889 let return_type =
2890 all_func_return_types.unwrap_record_element_type()[idx].clone();
2891 let (result, column_name) = Self::on_unique_window_agg(
2892 window_frame,
2893 arg,
2894 input_type,
2895 return_type,
2896 wrapped_aggr,
2897 );
2898 func_result_exprs.push(result);
2899 col_names.push(column_name);
2900 }
2901
2902 MirScalarExpr::call_variadic(
2903 ListCreate {
2904 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2905 },
2906 vec![MirScalarExpr::call_variadic(
2907 RecordCreate {
2908 field_names: vec![
2909 ColumnName::from("?fused_window_aggr?"),
2910 ColumnName::from("?record?"),
2911 ],
2912 },
2913 vec![
2914 MirScalarExpr::call_variadic(
2915 RecordCreate {
2916 field_names: col_names,
2917 },
2918 func_result_exprs,
2919 ),
2920 original_row,
2921 ],
2922 )],
2923 )
2924 }
2925
2926 // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
2927 AggregateFunc::FusedValueWindowFunc {
2928 funcs,
2929 order_by: outer_order_by,
2930 } => {
2931 // Throw away OrderByExprs
2932 let tuple = self
2933 .expr
2934 .clone()
2935 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2936
2937 // Extract the original row
2938 let original_row = tuple
2939 .clone()
2940 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2941
2942 // Extract the encoded args of the fused call
2943 let all_encoded_args =
2944 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2945
2946 let return_type_with_orig_row = self
2947 .typ(input_type)
2948 .scalar_type
2949 .unwrap_list_element_type()
2950 .clone();
2951
2952 let all_func_return_types =
2953 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2954 let mut func_result_exprs = Vec::new();
2955 let mut col_names = Vec::new();
2956 for (idx, func) in funcs.iter().enumerate() {
2957 let args_for_func = all_encoded_args
2958 .clone()
2959 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2960 let return_type_for_func =
2961 all_func_return_types.unwrap_record_element_type()[idx].clone();
2962 let (result, column_name) = match func {
2963 AggregateFunc::LagLead {
2964 lag_lead,
2965 order_by,
2966 ignore_nulls: _,
2967 } => {
2968 assert_eq!(order_by, outer_order_by);
2969 Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
2970 }
2971 AggregateFunc::FirstValue {
2972 window_frame,
2973 order_by,
2974 } => {
2975 assert_eq!(order_by, outer_order_by);
2976 Self::on_unique_first_value_last_value(
2977 window_frame,
2978 args_for_func,
2979 return_type_for_func,
2980 )
2981 }
2982 AggregateFunc::LastValue {
2983 window_frame,
2984 order_by,
2985 } => {
2986 assert_eq!(order_by, outer_order_by);
2987 Self::on_unique_first_value_last_value(
2988 window_frame,
2989 args_for_func,
2990 return_type_for_func,
2991 )
2992 }
2993 _ => panic!("unknown function in FusedValueWindowFunc"),
2994 };
2995 func_result_exprs.push(result);
2996 col_names.push(column_name);
2997 }
2998
2999 MirScalarExpr::call_variadic(
3000 ListCreate {
3001 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
3002 },
3003 vec![MirScalarExpr::call_variadic(
3004 RecordCreate {
3005 field_names: vec![
3006 ColumnName::from("?fused_value_window_func?"),
3007 ColumnName::from("?record?"),
3008 ],
3009 },
3010 vec![
3011 MirScalarExpr::call_variadic(
3012 RecordCreate {
3013 field_names: col_names,
3014 },
3015 func_result_exprs,
3016 ),
3017 original_row,
3018 ],
3019 )],
3020 )
3021 }
3022
3023 // All other variants should return the argument to the aggregation.
3024 AggregateFunc::MaxNumeric
3025 | AggregateFunc::MaxInt16
3026 | AggregateFunc::MaxInt32
3027 | AggregateFunc::MaxInt64
3028 | AggregateFunc::MaxUInt16
3029 | AggregateFunc::MaxUInt32
3030 | AggregateFunc::MaxUInt64
3031 | AggregateFunc::MaxMzTimestamp
3032 | AggregateFunc::MaxFloat32
3033 | AggregateFunc::MaxFloat64
3034 | AggregateFunc::MaxBool
3035 | AggregateFunc::MaxString
3036 | AggregateFunc::MaxDate
3037 | AggregateFunc::MaxTimestamp
3038 | AggregateFunc::MaxTimestampTz
3039 | AggregateFunc::MaxInterval
3040 | AggregateFunc::MaxTime
3041 | AggregateFunc::MinNumeric
3042 | AggregateFunc::MinInt16
3043 | AggregateFunc::MinInt32
3044 | AggregateFunc::MinInt64
3045 | AggregateFunc::MinUInt16
3046 | AggregateFunc::MinUInt32
3047 | AggregateFunc::MinUInt64
3048 | AggregateFunc::MinMzTimestamp
3049 | AggregateFunc::MinFloat32
3050 | AggregateFunc::MinFloat64
3051 | AggregateFunc::MinBool
3052 | AggregateFunc::MinString
3053 | AggregateFunc::MinDate
3054 | AggregateFunc::MinTimestamp
3055 | AggregateFunc::MinTimestampTz
3056 | AggregateFunc::MinInterval
3057 | AggregateFunc::MinTime
3058 | AggregateFunc::SumFloat32
3059 | AggregateFunc::SumFloat64
3060 | AggregateFunc::SumNumeric
3061 | AggregateFunc::Any
3062 | AggregateFunc::All
3063 | AggregateFunc::Dummy => self.expr.clone(),
3064 }
3065 }
3066
3067 /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3068 fn on_unique_ranking_window_funcs(
3069 &self,
3070 input_type: &[ReprColumnType],
3071 col_name: &str,
3072 ) -> MirScalarExpr {
3073 let sql_input_type: Vec<SqlColumnType> =
3074 input_type.iter().map(SqlColumnType::from_repr).collect();
3075 let list = self
3076 .expr
3077 .clone()
3078 // extract the list within the record
3079 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3080
3081 // extract the expression within the list
3082 let record = MirScalarExpr::call_variadic(
3083 ListIndex,
3084 vec![
3085 list,
3086 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3087 ],
3088 );
3089
3090 MirScalarExpr::call_variadic(
3091 ListCreate {
3092 elem_type: self
3093 .sql_typ(&sql_input_type)
3094 .scalar_type
3095 .unwrap_list_element_type()
3096 .clone(),
3097 },
3098 vec![MirScalarExpr::call_variadic(
3099 RecordCreate {
3100 field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3101 },
3102 vec![
3103 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3104 record,
3105 ],
3106 )],
3107 )
3108 }
3109
3110 /// `on_unique` for `lag` and `lead`
3111 fn on_unique_lag_lead(
3112 lag_lead: &LagLeadType,
3113 encoded_args: MirScalarExpr,
3114 return_type: ReprScalarType,
3115 ) -> (MirScalarExpr, ColumnName) {
3116 let expr = encoded_args
3117 .clone()
3118 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3119 let offset = encoded_args
3120 .clone()
3121 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3122 let default_value =
3123 encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3124
3125 // In this case, the window always has only one element, so if the offset is not null and
3126 // not zero, the default value should be returned instead.
3127 let value = offset
3128 .clone()
3129 .call_binary(
3130 MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
3131 crate::func::Eq,
3132 )
3133 .if_then_else(expr, default_value);
3134 let result_expr = offset
3135 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3136 .if_then_else(MirScalarExpr::literal_null(return_type), value);
3137
3138 let column_name = ColumnName::from(match lag_lead {
3139 LagLeadType::Lag => "?lag?",
3140 LagLeadType::Lead => "?lead?",
3141 });
3142
3143 (result_expr, column_name)
3144 }
3145
3146 /// `on_unique` for `first_value` and `last_value`
3147 fn on_unique_first_value_last_value(
3148 window_frame: &WindowFrame,
3149 arg: MirScalarExpr,
3150 return_type: ReprScalarType,
3151 ) -> (MirScalarExpr, ColumnName) {
3152 // If the window frame includes the current (single) row, return its value, null otherwise
3153 let result_expr = if window_frame.includes_current_row() {
3154 arg
3155 } else {
3156 MirScalarExpr::literal_null(return_type)
3157 };
3158 (result_expr, ColumnName::from("?first_value?"))
3159 }
3160
3161 /// `on_unique` for window aggregations
3162 fn on_unique_window_agg(
3163 window_frame: &WindowFrame,
3164 arg_expr: MirScalarExpr,
3165 input_type: &[ReprColumnType],
3166 return_type: ReprScalarType,
3167 wrapped_aggr: &AggregateFunc,
3168 ) -> (MirScalarExpr, ColumnName) {
3169 // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3170 // that row. Otherwise, return the default value for the aggregate.
3171 let result_expr = if window_frame.includes_current_row() {
3172 AggregateExpr {
3173 func: wrapped_aggr.clone(),
3174 expr: arg_expr,
3175 distinct: false, // We have just one input element; DISTINCT doesn't matter.
3176 }
3177 .on_unique(input_type)
3178 } else {
3179 MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3180 };
3181 (result_expr, ColumnName::from("?window_agg?"))
3182 }
3183
3184 /// Returns whether the expression is COUNT(*) or not. Note that
3185 /// when we define the count builtin in sql::func, we convert
3186 /// COUNT(*) to COUNT(true), making it indistinguishable from
3187 /// literal COUNT(true), but we prefer to consider this as the
3188 /// former.
3189 ///
3190 /// (HIR has the same `is_count_asterisk`.)
3191 pub fn is_count_asterisk(&self) -> bool {
3192 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3193 }
3194}
3195
/// Describes how a join is implemented in dataflow.
///
/// Note that `PartialOrd`/`Ord` are derived, so values compare first by
/// variant, in declaration order.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists the other relation indexes, each with
    /// the key of the arrangement to use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, one per input collection, in input order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        // The constant rows are excluded from MzReflect.
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected. This is the `Default` value.
    Unimplemented,
}
3254
3255impl Default for JoinImplementation {
3256 fn default() -> Self {
3257 JoinImplementation::Unimplemented
3258 }
3259}
3260
3261impl JoinImplementation {
3262 /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3263 pub fn is_implemented(&self) -> bool {
3264 match self {
3265 Self::Unimplemented => false,
3266 _ => true,
3267 }
3268 }
3269
3270 /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3271 pub fn name(&self) -> Option<&'static str> {
3272 match self {
3273 Self::Differential(..) => Some("differential"),
3274 Self::DeltaQuery(..) => Some("delta"),
3275 Self::IndexedFilter(..) => Some("indexed_filter"),
3276 Self::Unimplemented => None,
3277 }
3278 }
3279}
3280
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. [`JoinInputCharacteristics::new`] picks the version
/// based on the `enable_join_prioritize_arranged` feature flag.
///
/// The derived ordering compares the variant first, so any `V1` orders before any `V2`.
/// Presumably all candidates compared against each other use the same version
/// (TODO confirm at call sites).
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3310
3311impl JoinInputCharacteristics {
3312 /// Creates a new instance with the given characteristics.
3313 pub fn new(
3314 unique_key: bool,
3315 key_length: usize,
3316 arranged: bool,
3317 cardinality: Option<usize>,
3318 filters: FilterCharacteristics,
3319 input: usize,
3320 enable_join_prioritize_arranged: bool,
3321 ) -> Self {
3322 if enable_join_prioritize_arranged {
3323 Self::V2(JoinInputCharacteristicsV2::new(
3324 unique_key,
3325 key_length,
3326 arranged,
3327 cardinality,
3328 filters,
3329 input,
3330 ))
3331 } else {
3332 Self::V1(JoinInputCharacteristicsV1::new(
3333 unique_key,
3334 key_length,
3335 arranged,
3336 cardinality,
3337 filters,
3338 input,
3339 ))
3340 }
3341 }
3342
3343 /// Turns the instance into a String to be printed in EXPLAIN.
3344 pub fn explain(&self) -> String {
3345 match self {
3346 Self::V1(jic) => jic.explain(),
3347 Self::V2(jic) => jic.explain(),
3348 }
3349 }
3350
3351 /// Whether the join input described by `self` is arranged.
3352 pub fn arranged(&self) -> bool {
3353 match self {
3354 Self::V1(jic) => jic.arranged,
3355 Self::V2(jic) => jic.arranged,
3356 }
3357 }
3358
3359 /// Returns the `FilterCharacteristics` for the join input described by `self`.
3360 pub fn filters(&mut self) -> &mut FilterCharacteristics {
3361 match self {
3362 Self::V1(jic) => &mut jic.filters,
3363 Self::V2(jic) => &mut jic.filters,
3364 }
3365 }
3366}
3367
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// `Ord` is derived, so candidates compare lexicographically by field in
/// declaration order. The order of the fields below therefore *is* the priority
/// order of the characteristics; do not reorder them casually.
/// `cardinality` and `input` are wrapped in `Reverse`, so smaller values compare
/// as greater. This suggests that greater means "more preferred" (TODO confirm
/// at the join-ordering call site).
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better). Note that `None` compares below
    /// any `Some` under the derived ordering.
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3400
3401impl JoinInputCharacteristicsV2 {
3402 /// Creates a new instance with the given characteristics.
3403 pub fn new(
3404 unique_key: bool,
3405 key_length: usize,
3406 arranged: bool,
3407 cardinality: Option<usize>,
3408 filters: FilterCharacteristics,
3409 input: usize,
3410 ) -> Self {
3411 Self {
3412 unique_key,
3413 not_cross: key_length > 0,
3414 arranged,
3415 key_length,
3416 cardinality: cardinality.map(std::cmp::Reverse),
3417 filters,
3418 input: std::cmp::Reverse(input),
3419 }
3420 }
3421
3422 /// Turns the instance into a String to be printed in EXPLAIN.
3423 pub fn explain(&self) -> String {
3424 let mut e = "".to_owned();
3425 if self.unique_key {
3426 e.push_str("U");
3427 }
3428 // Don't need to print `not_cross`, because that is visible in the printed key.
3429 // if !self.not_cross {
3430 // e.push_str("C");
3431 // }
3432 for _ in 0..self.key_length {
3433 e.push_str("K");
3434 }
3435 if self.arranged {
3436 e.push_str("A");
3437 }
3438 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3439 e.push_str(&format!("|{cardinality}|"));
3440 }
3441 e.push_str(&self.filters.explain());
3442 e
3443 }
3444}
3445
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// `Ord` is derived, so candidates compare lexicographically by field in
/// declaration order. Here `key_length` outranks `arranged`; the V2 version
/// differs in this respect.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better). Note that `None` compares below
    /// any `Some` under the derived ordering.
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3473
3474impl JoinInputCharacteristicsV1 {
3475 /// Creates a new instance with the given characteristics.
3476 pub fn new(
3477 unique_key: bool,
3478 key_length: usize,
3479 arranged: bool,
3480 cardinality: Option<usize>,
3481 filters: FilterCharacteristics,
3482 input: usize,
3483 ) -> Self {
3484 Self {
3485 unique_key,
3486 key_length,
3487 arranged,
3488 cardinality: cardinality.map(std::cmp::Reverse),
3489 filters,
3490 input: std::cmp::Reverse(input),
3491 }
3492 }
3493
3494 /// Turns the instance into a String to be printed in EXPLAIN.
3495 pub fn explain(&self) -> String {
3496 let mut e = "".to_owned();
3497 if self.unique_key {
3498 e.push_str("U");
3499 }
3500 for _ in 0..self.key_length {
3501 e.push_str("K");
3502 }
3503 if self.arranged {
3504 e.push_str("A");
3505 }
3506 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3507 e.push_str(&format!("|{cardinality}|"));
3508 }
3509 e.push_str(&self.filters.explain());
3510 e
3511 }
3512}
3513
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Maximum number of rows to include, counted after `offset` has been
    /// applied. `None` means no limit.
    pub limit: Option<L>,
    /// Number of leading rows to skip.
    pub offset: O,
    /// Indices of the columns to include in the output. `0..arity` is the
    /// identity projection.
    pub project: Vec<usize>,
}
3536
3537impl<L> RowSetFinishing<L> {
3538 /// Returns a trivial finishing, i.e., that does nothing to the result set.
3539 pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3540 RowSetFinishing {
3541 order_by: Vec::new(),
3542 limit: None,
3543 offset: 0,
3544 project: (0..arity).collect(),
3545 }
3546 }
3547 /// True if the finishing does nothing to any result set.
3548 pub fn is_trivial(&self, arity: usize) -> bool {
3549 self.limit.is_none()
3550 && self.order_by.is_empty()
3551 && self.offset == 0
3552 && self.project.iter().copied().eq(0..arity)
3553 }
3554 /// True if the finishing does not require an ORDER BY.
3555 ///
3556 /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3557 /// explicit ordering we will skip an arbitrary bag of elements and return
3558 /// the first arbitrary elements in the remaining bag. The result semantics
3559 /// are still correct but maybe surprising for some users.
3560 pub fn is_streamable(&self, arity: usize) -> bool {
3561 self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3562 }
3563}
3564
3565impl RowSetFinishing<NonNeg<i64>, usize> {
3566 /// The number of rows needed from before the finishing to evaluate the finishing:
3567 /// offset + limit.
3568 ///
3569 /// If it returns None, then we need all the rows.
3570 pub fn num_rows_needed(&self) -> Option<usize> {
3571 self.limit
3572 .as_ref()
3573 .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3574 }
3575}
3576
3577impl RowSetFinishing {
3578 /// Applies finishing actions to a [`RowCollection`], and reports the total
3579 /// time it took to run.
3580 ///
3581 /// Returns a [`RowCollectionIter`] that contains all of the response data, as
3582 /// well as the size of the response in bytes.
3583 pub fn finish(
3584 &self,
3585 rows: RowCollection,
3586 max_result_size: u64,
3587 max_returned_query_size: Option<u64>,
3588 duration_histogram: &Histogram,
3589 ) -> Result<(RowCollectionIter, usize), String> {
3590 let now = Instant::now();
3591 let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3592 let duration = now.elapsed();
3593 duration_histogram.observe(duration.as_secs_f64());
3594
3595 result
3596 }
3597
3598 /// Implementation for [`RowSetFinishing::finish`].
3599 fn finish_inner(
3600 &self,
3601 rows: RowCollection,
3602 max_result_size: u64,
3603 max_returned_query_size: Option<u64>,
3604 ) -> Result<(RowCollectionIter, usize), String> {
3605 // How much additional memory is required to make a sorted view.
3606 let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3607 let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3608
3609 // Bail if creating the sorted view would require us to use too much memory.
3610 if required_memory > usize::cast_from(max_result_size) {
3611 let max_bytes = ByteSize::b(max_result_size);
3612 return Err(format!("result exceeds max size of {max_bytes}",));
3613 }
3614
3615 let sorted_view = rows;
3616 let mut iter = sorted_view
3617 .into_row_iter()
3618 .apply_offset(self.offset)
3619 .with_projection(self.project.clone());
3620
3621 if let Some(limit) = self.limit {
3622 let limit = u64::from(limit);
3623 let limit = usize::cast_from(limit);
3624 iter = iter.with_limit(limit);
3625 };
3626
3627 // TODO(parkmycar): Re-think how we can calculate the total response size without
3628 // having to iterate through the entire collection of Rows, while still
3629 // respecting the LIMIT, OFFSET, and projections.
3630 //
3631 // Note: It feels a bit bad always calculating the response size, but we almost
3632 // always need it to either check the `max_returned_query_size`, or for reporting
3633 // in the query history.
3634 let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3635
3636 // Bail if we would end up returning more data to the client than they can support.
3637 if let Some(max) = max_returned_query_size {
3638 if response_size > usize::cast_from(max) {
3639 let max_bytes = ByteSize::b(max);
3640 return Err(format!("result exceeds max size of {max_bytes}"));
3641 }
3642 }
3643
3644 Ok((iter, response_size))
3645 }
3646}
3647
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Number of rows still to include (after the offset) across the remaining
    /// batches. `None` means no limit.
    pub remaining_limit: Option<usize>,
    /// Number of rows still to skip before any rows are included.
    pub remaining_offset: usize,
    /// The maximum allowed result size in bytes, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size, in bytes.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Indices of the columns to include in the output.
    pub project: Vec<usize>,
}
3664
3665impl RowSetFinishingIncremental {
3666 /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3667 /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3668 /// `true`.
3669 ///
3670 /// # Panics
3671 ///
3672 /// Panics if the result is not streamable, that is it has an ORDER BY.
3673 pub fn new(
3674 offset: usize,
3675 limit: Option<NonNeg<i64>>,
3676 project: Vec<usize>,
3677 max_returned_query_size: Option<u64>,
3678 ) -> Self {
3679 let limit = limit.map(|l| {
3680 let l = u64::from(l);
3681 let l = usize::cast_from(l);
3682 l
3683 });
3684
3685 RowSetFinishingIncremental {
3686 remaining_limit: limit,
3687 remaining_offset: offset,
3688 max_returned_query_size,
3689 remaining_max_returned_query_size: max_returned_query_size,
3690 project,
3691 }
3692 }
3693
3694 /// Applies finishing actions to the given [`RowCollection`], and reports
3695 /// the total time it took to run.
3696 ///
3697 /// Returns a [`RowCollectionIter`] that contains all of the response
3698 /// data.
3699 pub fn finish_incremental(
3700 &mut self,
3701 rows: RowCollection,
3702 max_result_size: u64,
3703 duration_histogram: &Histogram,
3704 ) -> Result<RowCollectionIter, String> {
3705 let now = Instant::now();
3706 let result = self.finish_incremental_inner(rows, max_result_size);
3707 let duration = now.elapsed();
3708 duration_histogram.observe(duration.as_secs_f64());
3709
3710 result
3711 }
3712
3713 fn finish_incremental_inner(
3714 &mut self,
3715 rows: RowCollection,
3716 max_result_size: u64,
3717 ) -> Result<RowCollectionIter, String> {
3718 // How much additional memory is required to make a sorted view.
3719 let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3720 let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3721
3722 // Bail if creating the sorted view would require us to use too much memory.
3723 if required_memory > usize::cast_from(max_result_size) {
3724 let max_bytes = ByteSize::b(max_result_size);
3725 return Err(format!("total result exceeds max size of {max_bytes}",));
3726 }
3727
3728 let batch_num_rows = rows.count();
3729
3730 let sorted_view = rows;
3731 let mut iter = sorted_view
3732 .into_row_iter()
3733 .apply_offset(self.remaining_offset)
3734 .with_projection(self.project.clone());
3735
3736 if let Some(limit) = self.remaining_limit {
3737 iter = iter.with_limit(limit);
3738 };
3739
3740 self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3741 if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3742 *remaining_limit -= iter.count();
3743 }
3744
3745 // TODO(parkmycar): Re-think how we can calculate the total response size without
3746 // having to iterate through the entire collection of Rows, while still
3747 // respecting the LIMIT, OFFSET, and projections.
3748 //
3749 // Note: It feels a bit bad always calculating the response size, but we almost
3750 // always need it to either check the `max_returned_query_size`, or for reporting
3751 // in the query history.
3752 let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3753
3754 // Bail if we would end up returning more data to the client than they can support.
3755 if let Some(max) = &mut self.remaining_max_returned_query_size {
3756 if let Some(remaining) = max.checked_sub(response_size.cast_into()) {
3757 *max = remaining;
3758 } else {
3759 let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3760 return Err(format!("total result exceeds max size of {max_bytes}"));
3761 }
3762 }
3763
3764 Ok(iter)
3765 }
3766}
3767
/// Compares two rows columnwise, using [compare_columns].
///
/// Compared to the naive implementation, this allows sharing some memory and implements some
/// optimizations that avoid unnecessary row unpacking.
#[derive(Debug, Clone)]
pub struct RowComparator<O: AsRef<[ColumnOrder]> = Vec<ColumnOrder>> {
    /// The column ordering to compare rows by.
    order: O,
    /// Invariant: all column references in the order are less than this limit.
    /// This allows for partial unpacking of rows.
    limit: usize,
    /// Scratch buffer, reused across comparisons, for unpacking the left row.
    left_vec: RefCell<DatumVec>,
    /// Scratch buffer, reused across comparisons, for unpacking the right row.
    right_vec: RefCell<DatumVec>,
}
3781
3782impl<O: AsRef<[ColumnOrder]>> RowComparator<O> {
3783 /// Create a new row comparator from the given column ordering.
3784 pub fn new(order: O) -> Self {
3785 let limit = order
3786 .as_ref()
3787 .iter()
3788 .map(|o| o.column + 1)
3789 .max()
3790 .unwrap_or(0);
3791 Self {
3792 order,
3793 limit,
3794 left_vec: Default::default(),
3795 right_vec: Default::default(),
3796 }
3797 }
3798
3799 /// Compare two (references to) rows.
3800 pub fn compare_rows(
3801 &self,
3802 left_row: &RowRef,
3803 right_row: &RowRef,
3804 tiebreaker: impl Fn() -> Ordering,
3805 ) -> Ordering {
3806 let order = if self.limit == 0 {
3807 Ordering::Equal
3808 } else {
3809 // These borrows should never fail, since this struct is non-sync and this function
3810 // is non-recursive.
3811 let mut left_ref = self.left_vec.borrow_mut();
3812 let mut right_ref = self.right_vec.borrow_mut();
3813 let left_cols = left_ref.borrow_with_limit(left_row, self.limit);
3814 let right_cols = right_ref.borrow_with_limit(right_row, self.limit);
3815 compare_columns(self.order.as_ref(), &left_cols, &right_cols, || {
3816 Ordering::Equal
3817 })
3818 };
3819 // Tiebreak without the vecs borrowed, in case that recursively invokes this function.
3820 order.then_with(tiebreaker)
3821 }
3822}
3823
3824/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3825/// ordering, call `tiebreaker`.
3826pub fn compare_columns<F>(
3827 order: &[ColumnOrder],
3828 left: &[Datum],
3829 right: &[Datum],
3830 tiebreaker: F,
3831) -> Ordering
3832where
3833 F: Fn() -> Ordering,
3834{
3835 for order in order {
3836 let cmp = match (&left[order.column], &right[order.column]) {
3837 (Datum::Null, Datum::Null) => Ordering::Equal,
3838 (Datum::Null, _) => {
3839 if order.nulls_last {
3840 Ordering::Greater
3841 } else {
3842 Ordering::Less
3843 }
3844 }
3845 (_, Datum::Null) => {
3846 if order.nulls_last {
3847 Ordering::Less
3848 } else {
3849 Ordering::Greater
3850 }
3851 }
3852 (lval, rval) => {
3853 if order.desc {
3854 rval.cmp(lval)
3855 } else {
3856 lval.cmp(rval)
3857 }
3858 }
3859 };
3860 if cmp != Ordering::Equal {
3861 return cmp;
3862 }
3863 }
3864 tiebreaker()
3865}
3866
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3892
3893impl Display for WindowFrame {
3894 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3895 write!(
3896 f,
3897 "{} between {} and {}",
3898 self.units, self.start_bound, self.end_bound
3899 )
3900 }
3901}
3902
3903impl WindowFrame {
3904 /// Return the default window frame used when one is not explicitly defined
3905 pub fn default() -> Self {
3906 WindowFrame {
3907 units: WindowFrameUnits::Range,
3908 start_bound: WindowFrameBound::UnboundedPreceding,
3909 end_bound: WindowFrameBound::CurrentRow,
3910 }
3911 }
3912
3913 fn includes_current_row(&self) -> bool {
3914 use WindowFrameBound::*;
3915 match self.start_bound {
3916 UnboundedPreceding => match self.end_bound {
3917 UnboundedPreceding => false,
3918 OffsetPreceding(0) => true,
3919 OffsetPreceding(_) => false,
3920 CurrentRow => true,
3921 OffsetFollowing(_) => true,
3922 UnboundedFollowing => true,
3923 },
3924 OffsetPreceding(0) => match self.end_bound {
3925 UnboundedPreceding => unreachable!(),
3926 OffsetPreceding(0) => true,
3927 // Any nonzero offsets here will create an empty window
3928 OffsetPreceding(_) => false,
3929 CurrentRow => true,
3930 OffsetFollowing(_) => true,
3931 UnboundedFollowing => true,
3932 },
3933 OffsetPreceding(_) => match self.end_bound {
3934 UnboundedPreceding => unreachable!(),
3935 // Window ends at the current row
3936 OffsetPreceding(0) => true,
3937 OffsetPreceding(_) => false,
3938 CurrentRow => true,
3939 OffsetFollowing(_) => true,
3940 UnboundedFollowing => true,
3941 },
3942 CurrentRow => true,
3943 OffsetFollowing(0) => match self.end_bound {
3944 UnboundedPreceding => unreachable!(),
3945 OffsetPreceding(_) => unreachable!(),
3946 CurrentRow => unreachable!(),
3947 OffsetFollowing(_) => true,
3948 UnboundedFollowing => true,
3949 },
3950 OffsetFollowing(_) => match self.end_bound {
3951 UnboundedPreceding => unreachable!(),
3952 OffsetPreceding(_) => unreachable!(),
3953 CurrentRow => unreachable!(),
3954 OffsetFollowing(_) => false,
3955 UnboundedFollowing => false,
3956 },
3957 UnboundedFollowing => false,
3958 }
3959 }
3960}
3961
/// Describes how a window frame's bounds are interpreted.
///
/// The variant declaration order determines the derived `PartialOrd`/`Ord`.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds (`ROWS`).
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression (`RANGE`).
    Range,
    /// Each peer group is treated as the unit of work for bounds (`GROUPS`).
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3985
3986impl Display for WindowFrameUnits {
3987 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3988 match self {
3989 WindowFrameUnits::Rows => write!(f, "rows"),
3990 WindowFrameUnits::Range => write!(f, "range"),
3991 WindowFrameUnits::Groups => write!(f, "groups"),
3992 }
3993 }
3994}
3995
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The declaration order of the variants is significant: it runs from the earliest to the
/// latest position relative to the current row, and the derived `PartialOrd`/`Ord` follow it.
/// Postgres enforces restrictions on which start/end combinations are allowed.
///
/// NOTE(review): the derived `Ord` compares `OffsetPreceding` payloads in ascending numeric
/// order, so `OffsetPreceding(10) > OffsetPreceding(1)` even though `10 PRECEDING` lies *before*
/// `1 PRECEDING`. Confirm that no caller relies on `Ord` to compare two `OffsetPreceding` bounds
/// by position.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
4024
4025impl Display for WindowFrameBound {
4026 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4027 match self {
4028 WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4029 WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4030 WindowFrameBound::CurrentRow => write!(f, "current row"),
4031 WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4032 WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4033 }
4034 }
4035}
4036
/// Maximum iterations for a LetRec, and what to do when that maximum is reached.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// What happens when the above limit is reached.
    /// If true, we simply use the current contents of each Id as the final result;
    /// if false, an error is raised instead.
    pub return_at_limit: bool,
}
4057
4058impl LetRecLimit {
4059 /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4060 pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4061 limits
4062 .iter()
4063 .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4064 .min()
4065 }
4066
4067 /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4068 /// WMR without ERROR AT or RETURN AT.
4069 pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4070}
4071
4072impl Display for LetRecLimit {
4073 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4074 write!(f, "[recursion_limit={}", self.max_iters)?;
4075 if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4076 write!(f, ", return_at_limit")?;
4077 }
4078 write!(f, "]")
4079 }
4080}
4081
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes. Each entry pairs an index id with how that
    /// index will be used.
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4108
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// Checks the single-line EXPLAIN rendering of a `RowSetFinishing`.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            // A default offset does not appear in the expected rendering below.
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // Expected text: a single line. The consecutive projected columns #3, #4, #5 are
        // collapsed into the range `#3..=#5`.
        let exp = {
            use mz_ore::fmt::FormatBuffer;
            let mut s = String::new();
            write!(&mut s, "Finish");
            write!(&mut s, " order_by=[#4 desc nulls_last]");
            write!(&mut s, " limit=7");
            write!(&mut s, " output=[#1, #3..=#5]");
            writeln!(&mut s, "");
            s
        };

        assert_eq!(act, exp);
    }

    /// Checks that `remaining_max_returned_query_size` is decremented per batch and that
    /// exceeding it yields an error.
    #[mz_ore::test]
    fn test_row_set_finishing_incremental_max_returned_query_size() {
        // A batch containing a single row with multiplicity 1.
        let row = Row::pack_slice(&[Datum::String("hello")]);
        let row_size = u64::cast_from(row.data().len());
        let diff = NonZeroUsize::new(1).unwrap();
        let batch = RowCollection::new(vec![(row, diff)], &[]);

        // Set max_returned_query_size to hold exactly 2 batches worth of rows.
        let mut finishing = RowSetFinishingIncremental::new(0, None, vec![0], Some(row_size * 2));

        // Effectively unlimited, so only max_returned_query_size can trigger the error.
        let max_result_size = u64::MAX;

        // First batch: budget drops by one row's data length.
        let r = finishing.finish_incremental_inner(batch.clone(), max_result_size);
        assert!(r.is_ok());
        assert_eq!(finishing.remaining_max_returned_query_size, Some(row_size));

        // Second batch: budget is exactly exhausted, which is still allowed.
        let r = finishing.finish_incremental_inner(batch.clone(), max_result_size);
        assert!(r.is_ok());
        assert_eq!(finishing.remaining_max_returned_query_size, Some(0));

        // Third batch: exceeds the budget.
        let r = finishing.finish_incremental_inner(batch, max_result_size);
        assert!(r.unwrap_err().contains("total result exceeds max size"));
    }
}
4175
/// Iterators over pairs of AST structures that call out the nodes where they differ.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and incidentally to provide a
/// recursion-free way to compare two ASTs: an explicit work list replaces the call stack.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    ///
    /// Each item is a pair of corresponding nodes whose node-local data (or variant) differs.
    /// The children of a mismatched pair are not explored further.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        /// Used as a LIFO stack, so the traversal is depth-first.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        /// Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        /// Pops pairs from the work list until a mismatch is found, and returns it.
        ///
        /// Matching nodes have their child pairs pushed onto the work list. State persists
        /// between calls, so the next call resumes with the remaining queued pairs.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    // Leaves: compare all node data directly.
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            // Pushed last, so `value` is compared before `body`.
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        // The length check guarantees `zip_eq` below cannot panic.
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        // The length check guarantees `zip_eq` below cannot panic.
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Negate and Threshold carry no node-local data; only their inputs matter.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        // The length check guarantees `zip_eq` below cannot panic.
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Different variants always count as a mismatch.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            None
        }
    }
}