mz_expr/relation.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cell::RefCell;
13use std::cmp::{Ordering, max};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17use std::hash::{DefaultHasher, Hash, Hasher};
18use std::num::NonZeroU64;
19use std::time::Instant;
20
21use bytesize::ByteSize;
22use columnation::{Columnation, CopyRegion};
23use itertools::Itertools;
24use mz_lowertest::MzReflect;
25use mz_ore::cast::{CastFrom, CastInto};
26use mz_ore::collections::CollectionExt;
27use mz_ore::id_gen::IdGen;
28use mz_ore::metrics::Histogram;
29use mz_ore::num::NonNeg;
30use mz_ore::soft_assert_no_log;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::Indent;
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36 DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39 ColumnName, Datum, DatumVec, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
40 ReprScalarType, Row, RowIterator, RowRef, SqlColumnType, SqlRelationType, SqlScalarType,
41};
42use serde::{Deserialize, Serialize};
43
44use crate::Id::Local;
45use crate::explain::{HumanizedExpr, HumanizerMode};
46use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
47use crate::row::{RowCollection, RowCollectionIter};
48use crate::scalar::columns::Columns;
49use crate::scalar::func::variadic::{
50 JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
51};
52use crate::visit::{Visit, VisitChildren};
53use crate::{
54 EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
55};
56
57pub mod canonicalize;
58pub mod func;
59pub mod join_input_mapper;
60
/// Maximum recursion depth for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The value is intentionally generous, because some pathological yet frequently
/// occurring query fragments have a deep, linear MIR representation. In MIR we may
/// encounter long chains of
/// - (1) `Let` bindings, or
/// - (2) `CallBinary` calls with associative functions such as `+`.
///
/// Until those cases are handled differently, the larger limit is required.
pub const RECURSION_LIMIT: usize = 2_048;
72
73/// A trait for types that describe how to build a collection.
74pub trait CollectionPlan {
75 /// Collects the set of global identifiers from dataflows referenced in Get.
76 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
77
78 /// Returns the set of global identifiers from dataflows referenced in Get.
79 ///
80 /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
81 fn depends_on(&self) -> BTreeSet<GlobalId> {
82 let mut out = BTreeSet::new();
83 self.depends_on_into(&mut out);
84 out
85 }
86}
87
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
///
/// Note that the derived `Ord`/`PartialOrd` compare variants in declaration order (and fields in
/// declaration order), so reordering variants or fields below changes how expressions are ordered.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities; `Err` holds an
        /// evaluation error in place of rows.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: ReprRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: ReprRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id` (`values[i]` is bound to `ids[i]`).
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`: `body` evaluated once, with each `ids[i]` bound to the
        /// final iterate of `values[i]`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of input columns to retain: output column `j` is input column `outputs[j]`.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or to
        /// expressions defined earlier in the vector.
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Scalar expressions, evaluated on each input row, whose values form the group key.
        /// These are often, but not necessarily, plain column references.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain within each group, as a scalar expression.
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip within each group.
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// The first source collection.
        base: Box<MirRelationExpr>,
        /// The remaining source collections to union with `base`.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Each element of `keys` describes a
    /// separate index (arrangement) to produce from the input.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
316
317impl PartialEq for MirRelationExpr {
318 fn eq(&self, other: &Self) -> bool {
319 // Capture the result and test it wrt `Ord` implementation in test environments.
320 let result = structured_diff::MreDiff::new(self, other).next().is_none();
321 mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
322 result
323 }
324}
325impl Eq for MirRelationExpr {}
326
327impl MirRelationExpr {
328 /// Reports the schema of the relation.
329 ///
330 /// This is the SQL-type parallel of [`Self::typ`]; it is merely
331 /// a wrapper around it, returning a [`SqlRelationType`] instead of
332 /// a [`ReprRelationType`].
333 pub fn sql_typ(&self) -> SqlRelationType {
334 let repr_typ = self.typ();
335 SqlRelationType::from_repr(&repr_typ)
336 }
337
338 /// Reports the repr schema of the relation.
339 ///
340 /// This method determines the type through recursive traversal of the
341 /// relation expression, drawing from the types of base collections.
342 /// As such, this is not an especially cheap method, and should be used
343 /// judiciously.
344 ///
345 /// The relation type is computed incrementally with a recursive post-order
346 /// traversal, that accumulates the input types for the relations yet to be
347 /// visited in `type_stack`.
348 pub fn typ(&self) -> ReprRelationType {
349 let mut type_stack = Vec::new();
350 #[allow(deprecated)]
351 self.visit_pre_post_nolimit(
352 &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
353 match &e {
354 MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
355 MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
356 _ => None,
357 }
358 },
359 &mut |e: &MirRelationExpr| {
360 match e {
361 MirRelationExpr::Let { .. } => {
362 let body_typ = type_stack.pop().unwrap();
363 // Insert a dummy relation type for the value, since `typ_with_input_types`
364 // won't look at it, but expects the relation type of the body to be second.
365 type_stack.push(ReprRelationType::empty());
366 type_stack.push(body_typ);
367 }
368 MirRelationExpr::LetRec { values, .. } => {
369 let body_typ = type_stack.pop().unwrap();
370 type_stack.extend(
371 std::iter::repeat(ReprRelationType::empty()).take(values.len()),
372 );
373 // Insert dummy relation types for the values, since `typ_with_input_types`
374 // won't look at them, but expects the relation type of the body to be last.
375 type_stack.push(body_typ);
376 }
377 _ => {}
378 }
379 let num_inputs = e.num_inputs();
380 let relation_type =
381 e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
382 type_stack.truncate(type_stack.len() - num_inputs);
383 type_stack.push(relation_type);
384 },
385 );
386 assert_eq!(type_stack.len(), 1);
387 type_stack.pop().unwrap()
388 }
389
390 /// Reports the repr schema of the relation given the repr schema of the input relations.
391 pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
392 let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
393 let unique_keys = self.keys_with_input_keys(
394 input_types.iter().map(|i| i.arity()),
395 input_types.iter().map(|i| &i.keys),
396 );
397 ReprRelationType::new(column_types).with_keys(unique_keys)
398 }
399
400 /// Reports the column types of the relation given the column types of the
401 /// input relations.
402 ///
403 /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
404 /// variant is returned.
405 pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
406 where
407 I: Iterator<Item = &'a Vec<ReprColumnType>>,
408 {
409 match self.try_col_with_input_cols(input_types) {
410 Ok(col_types) => col_types,
411 Err(err) => panic!("{err}"),
412 }
413 }
414
415 /// Reports the column types of the relation given the column types of the input relations.
416 ///
417 /// `input_types` is required to contain the column types for the input relations of
418 /// the current relation in the same order as they are visited by `try_visit_children`
419 /// method, even though not all may be used for computing the schema of the
420 /// current relation. For example, `Let` expects two input types, one for the
421 /// value relation and one for the body, in that order, but only the one for the
422 /// body is used to determine the type of the `Let` relation.
423 ///
424 /// It is meant to be used during post-order traversals to compute column types
425 /// incrementally.
426 pub fn try_col_with_input_cols<'a, I>(
427 &self,
428 mut input_types: I,
429 ) -> Result<Vec<ReprColumnType>, String>
430 where
431 I: Iterator<Item = &'a Vec<ReprColumnType>>,
432 {
433 use MirRelationExpr::*;
434
435 let col_types = match self {
436 Constant { rows, typ } => {
437 let mut col_types = typ.column_types.clone();
438 let mut seen_null = vec![false; typ.arity()];
439 if let Ok(rows) = rows {
440 for (row, _diff) in rows {
441 for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
442 if datum.is_null() {
443 seen_null[i] = true;
444 }
445 }
446 }
447 }
448 for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
449 if !seen_null {
450 col_types[i].nullable = false;
451 } else {
452 assert!(col_types[i].nullable);
453 }
454 }
455 col_types
456 }
457 Get { typ, .. } => typ.column_types.clone(),
458 Project { outputs, .. } => {
459 let input = input_types.next().unwrap();
460 outputs.iter().map(|&i| input[i].clone()).collect()
461 }
462 Map { scalars, .. } => {
463 let mut result = input_types.next().unwrap().clone();
464 for scalar in scalars.iter() {
465 result.push(scalar.typ(&result))
466 }
467 result
468 }
469 FlatMap { func, .. } => {
470 let mut result = input_types.next().unwrap().clone();
471 result.extend(
472 func.output_sql_type()
473 .column_types
474 .iter()
475 .map(ReprColumnType::from),
476 );
477 result
478 }
479 Filter { predicates, .. } => {
480 let mut result = input_types.next().unwrap().clone();
481
482 // Set as nonnull any columns where null values would cause
483 // any predicate to evaluate to null.
484 for column in non_nullable_columns(predicates) {
485 result[column].nullable = false;
486 }
487 result
488 }
489 Join { equivalences, .. } => {
490 // Concatenate input column types
491 let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
492 // In an equivalence class, if any column is non-null, then make all non-null
493 for equivalence in equivalences {
494 let col_inds = equivalence
495 .iter()
496 .filter_map(|expr| match expr {
497 MirScalarExpr::Column(col, _name) => Some(*col),
498 _ => None,
499 })
500 .collect_vec();
501 if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
502 for i in col_inds {
503 types.get_mut(i).unwrap().nullable = false;
504 }
505 }
506 }
507 types
508 }
509 Reduce {
510 group_key,
511 aggregates,
512 ..
513 } => {
514 let input = input_types.next().unwrap();
515 group_key
516 .iter()
517 .map(|e| e.typ(input))
518 .chain(aggregates.iter().map(|agg| agg.typ(input)))
519 .collect()
520 }
521 TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
522 input_types.next().unwrap().clone()
523 }
524 Let { .. } => {
525 // skip over the input types for `value`.
526 input_types.nth(1).unwrap().clone()
527 }
528 LetRec { values, .. } => {
529 // skip over the input types for `values`.
530 input_types.nth(values.len()).unwrap().clone()
531 }
532 Union { .. } => {
533 let mut result = input_types.next().unwrap().clone();
534 for input_col_types in input_types {
535 for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
536 *base_col = base_col
537 .union(col)
538 .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
539 }
540 }
541 result
542 }
543 };
544
545 Ok(col_types)
546 }
547
    /// Reports the unique keys of the relation given the arities and the unique
    /// keys of the input relations.
    ///
    /// `input_arities` and `input_keys` are required to contain the
    /// corresponding info for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute unique keys
    /// incrementally.
    ///
    /// Each key is a list of output column indices. The returned list of keys is sorted and
    /// deduplicated. An empty key (`vec![]`) indicates that the relation has at most one row.
    /// The analysis is conservative: valid keys may be missed (for example, multi-column
    /// keys of constants are not detected).
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if `input_arities` or `input_keys` yield fewer entries than the
    /// variant consumes.
    pub fn keys_with_input_keys<'a, I, J>(
        &self,
        mut input_arities: I,
        mut input_keys: J,
    ) -> Vec<Vec<usize>>
    where
        I: Iterator<Item = usize>,
        J: Iterator<Item = &'a Vec<Vec<usize>>>,
    {
        use MirRelationExpr::*;

        let mut keys = match self {
            Constant {
                rows: Ok(rows),
                typ,
            } => {
                let n_cols = typ.arity();
                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
                for (row, diff) in rows {
                    for (i, datum) in row.iter().enumerate() {
                        // `Dummy` datums are ignored by this uniqueness check.
                        if datum != Datum::Dummy {
                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
                                // A repeated value, or a row whose multiplicity is not exactly
                                // one, rules the column out as a key.
                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
                                if is_dupe {
                                    unique_values_per_col[i] = None;
                                }
                            }
                        }
                    }
                }
                // No rows, or a single row with multiplicity one: at most one row, so the
                // empty key applies.
                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
                    vec![vec![]]
                } else {
                    // XXX - Multi-column keys are not detected.
                    typ.keys
                        .iter()
                        .cloned()
                        .chain(
                            unique_values_per_col
                                .into_iter()
                                .enumerate()
                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
                                .map(|(idx, _)| vec![idx]),
                        )
                        .collect()
                }
            }
            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
            // The input's keys carry over unchanged.
            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
            Let { .. } => {
                // skip over the unique keys for value
                input_keys.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the unique keys for values
                input_keys.nth(values.len()).unwrap().clone()
            }
            Project { outputs, .. } => {
                // Keep only keys whose columns all survive the projection, renumbering each
                // column to the position of its first occurrence in `outputs`.
                let input = input_keys.next().unwrap();
                input
                    .iter()
                    .filter_map(|key_set| {
                        if key_set.iter().all(|k| outputs.contains(k)) {
                            Some(
                                key_set
                                    .iter()
                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
                                    .collect(),
                            )
                        } else {
                            None
                        }
                    })
                    .collect()
            }
            Map { scalars, .. } => {
                // Pairs `(c, new)`: new output column `new` is a uniqueness-preserving
                // image of column `c`.
                let mut remappings = Vec::new();
                let arity = input_arities.next().unwrap();
                for (column, scalar) in scalars.iter().enumerate() {
                    // assess whether the scalar preserves uniqueness,
                    // and could participate in a key!

                    // Returns the column `expr` refers to, if `expr` is a column reference
                    // wrapped only in unary functions that report `preserves_uniqueness()`.
                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
                        match expr {
                            MirScalarExpr::CallUnary { func, expr } => {
                                if func.preserves_uniqueness() {
                                    uniqueness(expr)
                                } else {
                                    None
                                }
                            }
                            MirScalarExpr::Column(c, _name) => Some(*c),
                            _ => None,
                        }
                    }

                    if let Some(c) = uniqueness(scalar) {
                        remappings.push((c, column + arity));
                    }
                }

                let mut result = input_keys.next().unwrap().clone();
                let mut new_keys = Vec::new();
                // Any column in `remappings` could be replaced in a key
                // by the corresponding c. This could lead to combinatorial
                // explosion using our current representation, so we won't
                // do that. Instead, we'll handle the case of one remapping.
                if remappings.len() == 1 {
                    let (old, new) = remappings.pop().unwrap();
                    for key in &result {
                        if key.contains(&old) {
                            let mut new_key: Vec<usize> =
                                key.iter().cloned().filter(|k| k != &old).collect();
                            new_key.push(new);
                            new_key.sort_unstable();
                            new_keys.push(new_key);
                        }
                    }
                    result.append(&mut new_keys);
                }
                result
            }
            FlatMap { .. } => {
                // FlatMap can add duplicate rows, so input keys are no longer
                // valid
                vec![]
            }
            Negate { .. } => {
                // Although negate may have distinct records for each key,
                // the multiplicity is -1 rather than 1. This breaks many
                // of the optimization uses of "keys".
                vec![]
            }
            Filter { predicates, .. } => {
                // A filter inherits the keys of its input unless the filters
                // have reduced the input to a single row, in which case the
                // keys of the input are `()`.
                let mut input = input_keys.next().unwrap().clone();

                if !input.is_empty() {
                    // Track columns equated to literals, which we can prune.
                    let mut cols_equal_to_literal = BTreeSet::new();

                    // Perform union find on `col1 = col2` to establish
                    // connected components of equated columns. Absent any
                    // equalities, this will be `0 .. #c` (where #c is the
                    // greatest column referenced by a predicate), but each
                    // equality will orient the root of the greater to the root
                    // of the lesser.
                    let mut union_find = Vec::new();

                    for expr in predicates.iter() {
                        if let MirScalarExpr::CallBinary {
                            func: crate::BinaryFunc::Eq(_),
                            expr1,
                            expr2,
                        } = expr
                        {
                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
                                if expr2.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
                                if expr1.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            // Perform union-find to equate columns.
                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
                                if c1 != c2 {
                                    // Ensure union_find has entries up to
                                    // max(c1, c2) by filling up missing
                                    // positions with identity mappings.
                                    while union_find.len() <= std::cmp::max(c1, c2) {
                                        union_find.push(union_find.len());
                                    }
                                    let mut r1 = c1; // Find the representative column of [c1].
                                    while r1 != union_find[r1] {
                                        assert!(union_find[r1] < r1);
                                        r1 = union_find[r1];
                                    }
                                    let mut r2 = c2; // Find the representative column of [c2].
                                    while r2 != union_find[r2] {
                                        assert!(union_find[r2] < r2);
                                        r2 = union_find[r2];
                                    }
                                    // Union [c1] and [c2] by pointing the
                                    // larger to the smaller representative (we
                                    // update the remaining equivalence class
                                    // members only once after this for-loop).
                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
                                }
                            }
                        }
                    }

                    // Complete union-find by pointing each element at its representative column.
                    for i in 0..union_find.len() {
                        // Iteration not required, as each prior already references the right column.
                        // (Invariant: `union_find[i] <= i`, and lower entries were already resolved.)
                        union_find[i] = union_find[union_find[i]];
                    }

                    // Remove columns bound to literals, and remap columns equated to earlier columns.
                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
                    for key_set in &mut input {
                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
                        for col in key_set.iter_mut() {
                            if let Some(equiv) = union_find.get(*col) {
                                *col = *equiv;
                            }
                        }
                        key_set.sort();
                        key_set.dedup();
                    }
                    input.sort();
                    input.dedup();

                    // Expand out each key to each of its equivalent forms.
                    // Each instance of `col` can be replaced by any equivalent column.
                    // This has the potential to result in exponentially sized number of unique keys,
                    // and in the future we should probably maintain unique keys modulo equivalence.

                    // First, compute an inverse map from each representative
                    // column `sub` to all other equivalent columns `col`.
                    let mut subs = Vec::new();
                    for (col, sub) in union_find.iter().enumerate() {
                        if *sub != col {
                            assert!(*sub < col);
                            while subs.len() <= *sub {
                                subs.push(Vec::new());
                            }
                            subs[*sub].push(col);
                        }
                    }
                    // For each column, substitute for it in each occurrence.
                    let mut to_add = Vec::new();
                    for (col, subs) in subs.iter().enumerate() {
                        if !subs.is_empty() {
                            for key_set in input.iter() {
                                if key_set.contains(&col) {
                                    let mut to_extend = key_set.clone();
                                    to_extend.retain(|c| c != &col);
                                    for sub in subs {
                                        to_extend.push(*sub);
                                        to_add.push(to_extend.clone());
                                        to_extend.pop();
                                    }
                                }
                            }
                        }
                        // No deduplication, as we cannot introduce duplicates.
                        // Appending inside this loop means that substitutions for later
                        // columns are also applied to keys generated for earlier columns.
                        input.append(&mut to_add);
                    }
                    for key_set in input.iter_mut() {
                        key_set.sort();
                        key_set.dedup();
                    }
                }
                input
            }
            Join { equivalences, .. } => {
                // It is important the `new_from_input_arities` constructor is
                // used. Otherwise, Materialize may potentially end up in an
                // infinite loop.
                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);

                input_mapper.global_keys(input_keys, equivalences)
            }
            Reduce { group_key, .. } => {
                // The group key should form a key, but we might already have
                // keys that are subsets of the group key, and should retain
                // those instead, if so.
                let mut result = Vec::new();
                for key_set in input_keys.next().unwrap() {
                    // Only input keys made entirely of plain column references in the group
                    // key survive; they are renumbered to their position in the group key.
                    if key_set
                        .iter()
                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
                    {
                        result.push(
                            key_set
                                .iter()
                                .map(|i| {
                                    group_key
                                        .iter()
                                        .position(|k| k == &MirScalarExpr::column(*i))
                                        .unwrap()
                                })
                                .collect::<Vec<_>>(),
                        );
                    }
                }
                if result.is_empty() {
                    result.push((0..group_key.len()).collect());
                }
                result
            }
            TopK {
                group_key, limit, ..
            } => {
                // If `limit` is `Some(1)` then the group key will become
                // a unique key, as there will be only one record with that key.
                let mut result = input_keys.next().unwrap().clone();
                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
                    result.push(group_key.clone())
                }
                result
            }
            Union { base, inputs } => {
                // Generally, unions do not have any unique keys, because
                // each input might duplicate some. However, there is at
                // least one idiomatic structure that does preserve keys,
                // which results from SQL aggregations that must populate
                // absent records with default values. In that pattern,
                // the union of one GET with its negation, which has first
                // been subjected to a projection and map, we can remove
                // their influence on the key structure.
                //
                // If there are A, B, each with a unique `key` such that
                // we are looking at
                //
                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
                //
                // Then we can report `key` as a unique key.
                //
                // TODO: make unique key structure an optimization analysis
                // rather than part of the type information.
                // TODO: perhaps ensure that (above) A.proj(key) is a
                // subset of B, as otherwise there are negative records
                // and who knows what is true (not expected, but again
                // who knows what the query plan might look like).

                let arity = input_arities.next().unwrap();
                let (base_projection, base_with_project_stripped) =
                    if let MirRelationExpr::Project { input, outputs } = &**base {
                        (outputs.clone(), &**input)
                    } else {
                        // A input without a project is equivalent to an input
                        // with the project being all columns in the input in order.
                        ((0..arity).collect::<Vec<_>>(), &**base)
                    };
                let mut result = Vec::new();
                if let MirRelationExpr::Get {
                    id: first_id,
                    typ: _,
                    ..
                } = base_with_project_stripped
                {
                    if inputs.len() == 1 {
                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
                            if let MirRelationExpr::Union { base, inputs } = &**input {
                                if inputs.len() == 1 {
                                    if let Some((input, outputs)) = base.is_negated_project() {
                                        if let MirRelationExpr::Get {
                                            id: second_id,
                                            typ: _,
                                            ..
                                        } = input
                                        {
                                            if first_id == second_id {
                                                // Keep only keys whose columns are left in
                                                // place by both projections.
                                                result.extend(
                                                    input_keys
                                                        .next()
                                                        .unwrap()
                                                        .into_iter()
                                                        .filter(|key| {
                                                            key.iter().all(|c| {
                                                                outputs.get(*c) == Some(c)
                                                                    && base_projection.get(*c)
                                                                        == Some(c)
                                                            })
                                                        })
                                                        .cloned(),
                                                );
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // Important: do not inherit keys of either input, as not unique.
                result
            }
        };
        // Normalize the list of keys: sort it and drop duplicate keys.
        keys.sort();
        keys.dedup();
        keys
    }
952
    /// The number of columns in the relation.
    ///
    /// The arity is derived from the structure of the expression together with
    /// the types stored on leaf `Constant` and `Get` nodes (see
    /// `arity_with_input_arities`). It does not compute the full relation type,
    /// but it still walks the expression tree, so it is not free.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    ///
    /// # Panics
    ///
    /// Panics if the traversal does not leave exactly one arity on
    /// `arity_stack`, which would indicate an internal inconsistency.
    pub fn arity(&self) -> usize {
        let mut arity_stack = Vec::new();
        // NOTE(review): `visit_pre_post_nolimit` is deprecated, presumably because it
        // performs no recursion-limit checking (cf. the `_nolimit` suffix); the
        // deprecation warning is silenced deliberately here.
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-order callback: `Some(children)` restricts which children are
            // traversed (an empty list prunes the subtree), `None` traverses all.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post-order callback: first push a placeholder `0` for every child
            // that the pre-order callback skipped, so that the top `num_inputs()`
            // stack entries line up with this node's children in order; then
            // replace those entries with this node's own arity.
            &mut |e: &MirRelationExpr| {
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // The value was skipped: slot a placeholder below the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // One placeholder per skipped value, then the body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // The (single) input was skipped; its arity is not needed.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1010
1011 /// Reports the arity of the relation given the schema of the input relations.
1012 ///
1013 /// `input_arities` is required to contain the arities for the input relations of
1014 /// the current relation in the same order as they are visited by `try_visit_children`
1015 /// method, even though not all may be used for computing the schema of the
1016 /// current relation. For example, `Let` expects two input types, one for the
1017 /// value relation and one for the body, in that order, but only the one for the
1018 /// body is used to determine the type of the `Let` relation.
1019 ///
1020 /// It is meant to be used during post-order traversals to compute arities
1021 /// incrementally.
1022 pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1023 where
1024 I: Iterator<Item = usize>,
1025 {
1026 use MirRelationExpr::*;
1027
1028 match self {
1029 Constant { rows: _, typ } => typ.arity(),
1030 Get { typ, .. } => typ.arity(),
1031 Let { .. } => {
1032 input_arities.next();
1033 input_arities.next().unwrap()
1034 }
1035 LetRec { values, .. } => {
1036 for _ in 0..values.len() {
1037 input_arities.next();
1038 }
1039 input_arities.next().unwrap()
1040 }
1041 Project { outputs, .. } => outputs.len(),
1042 Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1043 FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1044 Join { .. } => input_arities.sum(),
1045 Reduce {
1046 input: _,
1047 group_key,
1048 aggregates,
1049 ..
1050 } => group_key.len() + aggregates.len(),
1051 Filter { .. }
1052 | TopK { .. }
1053 | Negate { .. }
1054 | Threshold { .. }
1055 | Union { .. }
1056 | ArrangeBy { .. } => input_arities.next().unwrap(),
1057 }
1058 }
1059
1060 /// The number of child relations this relation has.
1061 pub fn num_inputs(&self) -> usize {
1062 let mut count = 0;
1063
1064 self.visit_children(|_| count += 1);
1065
1066 count
1067 }
1068
1069 /// Constructs a constant collection from specific rows and schema, where
1070 /// each row will have a multiplicity of one.
1071 pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1072 let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1073 MirRelationExpr::constant_diff(rows, typ)
1074 }
1075
1076 /// Constructs a constant collection from specific rows and schema, where
1077 /// each row can have an arbitrary multiplicity.
1078 pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1079 for (row, _diff) in &rows {
1080 for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1081 assert!(
1082 datum.is_instance_of(column_typ),
1083 "Expected datum of type {:?}, got value {:?}",
1084 column_typ,
1085 datum
1086 );
1087 }
1088 }
1089 let rows = Ok(rows
1090 .into_iter()
1091 .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1092 .collect());
1093 MirRelationExpr::Constant { rows, typ }
1094 }
1095
1096 /// If self is a constant, return the value and the type, otherwise `None`.
1097 /// Looks behind `ArrangeBy`s.
1098 pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1099 match self {
1100 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1101 MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1102 _ => None,
1103 }
1104 }
1105
1106 /// If self is a constant, mutably return the value and the type, otherwise `None`.
1107 /// Looks behind `ArrangeBy`s.
1108 pub fn as_const_mut(
1109 &mut self,
1110 ) -> Option<(
1111 &mut Result<Vec<(Row, Diff)>, EvalError>,
1112 &mut ReprRelationType,
1113 )> {
1114 match self {
1115 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1116 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1117 _ => None,
1118 }
1119 }
1120
1121 /// If self is a constant error, return the error, otherwise `None`.
1122 /// Looks behind `ArrangeBy`s.
1123 pub fn as_const_err(&self) -> Option<&EvalError> {
1124 match self {
1125 MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1126 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1127 _ => None,
1128 }
1129 }
1130
1131 /// Checks if `self` is the single element collection with no columns.
1132 pub fn is_constant_singleton(&self) -> bool {
1133 if let Some((Ok(rows), typ)) = self.as_const() {
1134 rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1135 } else {
1136 false
1137 }
1138 }
1139
1140 /// Constructs the expression for getting a local collection.
1141 pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1142 MirRelationExpr::Get {
1143 id: Id::Local(id),
1144 typ,
1145 access_strategy: AccessStrategy::UnknownOrLocal,
1146 }
1147 }
1148
1149 /// Constructs the expression for getting a global collection
1150 pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1151 MirRelationExpr::Get {
1152 id: Id::Global(id),
1153 typ,
1154 access_strategy: AccessStrategy::UnknownOrLocal,
1155 }
1156 }
1157
1158 /// Retains only the columns specified by `output`.
1159 pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1160 if let MirRelationExpr::Project {
1161 outputs: columns, ..
1162 } = &mut self
1163 {
1164 // Update `outputs` to reference base columns of `input`.
1165 for column in outputs.iter_mut() {
1166 *column = columns[*column];
1167 }
1168 *columns = outputs;
1169 self
1170 } else {
1171 MirRelationExpr::Project {
1172 input: Box::new(self),
1173 outputs,
1174 }
1175 }
1176 }
1177
1178 /// Append to each row the results of applying elements of `scalar`.
1179 pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1180 if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1181 s.extend(scalars);
1182 self
1183 } else if !scalars.is_empty() {
1184 MirRelationExpr::Map {
1185 input: Box::new(self),
1186 scalars,
1187 }
1188 } else {
1189 self
1190 }
1191 }
1192
1193 /// Append to each row a single `scalar`.
1194 pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1195 self.map(vec![scalar])
1196 }
1197
1198 /// Like `map`, but yields zero-or-more output rows per input row
1199 pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1200 MirRelationExpr::FlatMap {
1201 input: Box::new(self),
1202 func,
1203 exprs,
1204 }
1205 }
1206
1207 /// Retain only the rows satisfying each of several predicates.
1208 pub fn filter<I>(mut self, predicates: I) -> Self
1209 where
1210 I: IntoIterator<Item = MirScalarExpr>,
1211 {
1212 // Extract existing predicates
1213 let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1214 self = *input;
1215 predicates
1216 } else {
1217 Vec::new()
1218 };
1219 // Normalize collection of predicates.
1220 new_predicates.extend(predicates);
1221 new_predicates.retain(|p| !p.is_literal_true());
1222 new_predicates.sort();
1223 new_predicates.dedup();
1224 // Introduce a `Filter` only if we have predicates.
1225 if !new_predicates.is_empty() {
1226 self = MirRelationExpr::Filter {
1227 input: Box::new(self),
1228 predicates: new_predicates,
1229 };
1230 }
1231
1232 self
1233 }
1234
1235 /// Form the Cartesian outer-product of rows in both inputs.
1236 pub fn product(mut self, right: Self) -> Self {
1237 if right.is_constant_singleton() {
1238 self
1239 } else if self.is_constant_singleton() {
1240 right
1241 } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1242 inputs.push(right);
1243 self
1244 } else {
1245 MirRelationExpr::join(vec![self, right], vec![])
1246 }
1247 }
1248
1249 /// Performs a relational equijoin among the input collections.
1250 ///
1251 /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1252 /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1253 /// of pairs `(input_index, column_index)` where every value described by that set must be equal.
1254 ///
1255 /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1256 /// input collection and for each row examining its `column`th column.
1257 ///
1258 /// # Example
1259 ///
1260 /// ```rust
1261 /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1262 /// use mz_expr::MirRelationExpr;
1263 ///
1264 /// // A common schema for each input.
1265 /// let schema = ReprRelationType::new(vec![
1266 /// ReprScalarType::Int32.nullable(false),
1267 /// ReprScalarType::Int32.nullable(false),
1268 /// ]);
1269 ///
1270 /// // the specific data are not important here.
1271 /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1272 ///
1273 /// // Three collections that could have been different.
1274 /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275 /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1276 /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1277 ///
1278 /// // Join the three relations looking for triangles, like so.
1279 /// //
1280 /// // Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1281 /// let joined = MirRelationExpr::join(
1282 /// vec![input0, input1, input2],
1283 /// vec![
1284 /// vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1285 /// vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1286 /// vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1287 /// ],
1288 /// );
1289 ///
1290 /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1291 /// // A projection resolves this and produces the correct output.
1292 /// let result = joined.project(vec![0, 1, 3]);
1293 /// ```
1294 pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1295 let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1296
1297 let equivalences = variables
1298 .into_iter()
1299 .map(|vs| {
1300 vs.into_iter()
1301 .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1302 .collect::<Vec<_>>()
1303 })
1304 .collect::<Vec<_>>();
1305
1306 Self::join_scalars(inputs, equivalences)
1307 }
1308
1309 /// Constructs a join operator from inputs and required-equal scalar expressions.
1310 pub fn join_scalars(
1311 mut inputs: Vec<MirRelationExpr>,
1312 equivalences: Vec<Vec<MirScalarExpr>>,
1313 ) -> Self {
1314 // Remove all constant inputs that are the identity for join.
1315 // They neither introduce nor modify any column references.
1316 inputs.retain(|i| !i.is_constant_singleton());
1317 MirRelationExpr::Join {
1318 inputs,
1319 equivalences,
1320 implementation: JoinImplementation::Unimplemented,
1321 }
1322 }
1323
1324 /// Perform a key-wise reduction / aggregation.
1325 ///
1326 /// The `group_key` argument indicates columns in the input collection that should
1327 /// be grouped, and `aggregates` lists aggregation functions each of which produces
1328 /// one output column in addition to the keys.
1329 pub fn reduce(
1330 self,
1331 group_key: Vec<usize>,
1332 aggregates: Vec<AggregateExpr>,
1333 expected_group_size: Option<u64>,
1334 ) -> Self {
1335 MirRelationExpr::Reduce {
1336 input: Box::new(self),
1337 group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1338 aggregates,
1339 monotonic: false,
1340 expected_group_size,
1341 }
1342 }
1343
1344 /// Perform a key-wise reduction order by and limit.
1345 ///
1346 /// The `group_key` argument indicates columns in the input collection that should
1347 /// be grouped, the `order_key` argument indicates columns that should be further
1348 /// used to order records within groups, and the `limit` argument constrains the
1349 /// total number of records that should be produced in each group.
1350 pub fn top_k(
1351 self,
1352 group_key: Vec<usize>,
1353 order_key: Vec<ColumnOrder>,
1354 limit: Option<MirScalarExpr>,
1355 offset: usize,
1356 expected_group_size: Option<u64>,
1357 ) -> Self {
1358 MirRelationExpr::TopK {
1359 input: Box::new(self),
1360 group_key,
1361 order_key,
1362 limit,
1363 offset,
1364 expected_group_size,
1365 monotonic: false,
1366 }
1367 }
1368
1369 /// Negates the occurrences of each row.
1370 pub fn negate(self) -> Self {
1371 if let MirRelationExpr::Negate { input } = self {
1372 *input
1373 } else {
1374 MirRelationExpr::Negate {
1375 input: Box::new(self),
1376 }
1377 }
1378 }
1379
1380 /// Removes all but the first occurrence of each row.
1381 pub fn distinct(self) -> Self {
1382 let arity = self.arity();
1383 self.distinct_by((0..arity).collect())
1384 }
1385
1386 /// Removes all but the first occurrence of each key. Columns not included
1387 /// in the `group_key` are discarded.
1388 pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1389 self.reduce(group_key, vec![], None)
1390 }
1391
1392 /// Discards rows with a negative frequency.
1393 pub fn threshold(self) -> Self {
1394 if let MirRelationExpr::Threshold { .. } = &self {
1395 self
1396 } else {
1397 MirRelationExpr::Threshold {
1398 input: Box::new(self),
1399 }
1400 }
1401 }
1402
1403 /// Unions together any number inputs.
1404 ///
1405 /// If `inputs` is empty, then an empty relation of type `typ` is
1406 /// constructed.
1407 pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1408 // Deconstruct `inputs` as `Union`s and reconstitute.
1409 let mut flat_inputs = Vec::with_capacity(inputs.len());
1410 for input in inputs {
1411 if let MirRelationExpr::Union { base, inputs } = input {
1412 flat_inputs.push(*base);
1413 flat_inputs.extend(inputs);
1414 } else {
1415 flat_inputs.push(input);
1416 }
1417 }
1418 inputs = flat_inputs;
1419 if inputs.len() == 0 {
1420 MirRelationExpr::Constant {
1421 rows: Ok(vec![]),
1422 typ,
1423 }
1424 } else if inputs.len() == 1 {
1425 inputs.into_element()
1426 } else {
1427 MirRelationExpr::Union {
1428 base: Box::new(inputs.remove(0)),
1429 inputs,
1430 }
1431 }
1432 }
1433
1434 /// Produces one collection where each row is present with the sum of its frequencies in each input.
1435 pub fn union(self, other: Self) -> Self {
1436 // Deconstruct `self` and `other` as `Union`s and reconstitute.
1437 let mut flat_inputs = Vec::with_capacity(2);
1438 if let MirRelationExpr::Union { base, inputs } = self {
1439 flat_inputs.push(*base);
1440 flat_inputs.extend(inputs);
1441 } else {
1442 flat_inputs.push(self);
1443 }
1444 if let MirRelationExpr::Union { base, inputs } = other {
1445 flat_inputs.push(*base);
1446 flat_inputs.extend(inputs);
1447 } else {
1448 flat_inputs.push(other);
1449 }
1450
1451 MirRelationExpr::Union {
1452 base: Box::new(flat_inputs.remove(0)),
1453 inputs: flat_inputs,
1454 }
1455 }
1456
1457 /// Arranges the collection by the specified columns
1458 pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1459 MirRelationExpr::ArrangeBy {
1460 input: Box::new(self),
1461 keys: keys.to_owned(),
1462 }
1463 }
1464
1465 /// Indicates if this is a constant empty collection.
1466 ///
1467 /// A false value does not mean the collection is known to be non-empty,
1468 /// only that we cannot currently determine that it is statically empty.
1469 pub fn is_empty(&self) -> bool {
1470 if let Some((Ok(rows), ..)) = self.as_const() {
1471 rows.is_empty()
1472 } else {
1473 false
1474 }
1475 }
1476
1477 /// If the expression is a negated project, return the input and the projection.
1478 pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1479 if let MirRelationExpr::Negate { input } = self {
1480 if let MirRelationExpr::Project { input, outputs } = &**input {
1481 return Some((&**input, outputs));
1482 }
1483 }
1484 if let MirRelationExpr::Project { input, outputs } = self {
1485 if let MirRelationExpr::Negate { input } = &**input {
1486 return Some((&**input, outputs));
1487 }
1488 }
1489 None
1490 }
1491
1492 /// Pretty-print this [MirRelationExpr] to a string.
1493 pub fn pretty(&self) -> String {
1494 let config = ExplainConfig::default();
1495 self.debug_explain(&config, None)
1496 }
1497
1498 /// Pretty-print this [MirRelationExpr] to a string using a custom
1499 /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1500 /// This is intended for debugging and tests, not users.
1501 pub fn debug_explain(
1502 &self,
1503 config: &ExplainConfig,
1504 humanizer: Option<&dyn ExprHumanizer>,
1505 ) -> String {
1506 text_string_at(self, || PlanRenderingContext {
1507 indent: Indent::default(),
1508 humanizer: humanizer.unwrap_or(&DummyHumanizer),
1509 annotations: BTreeMap::default(),
1510 config,
1511 ambiguous_ids: BTreeSet::default(),
1512 })
1513 }
1514
1515 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1516 /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1517 /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1518 /// the best possible key and nullability, since we are making an empty collection.
1519 ///
1520 /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1521 /// the correct type.
1522 pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1523 if let Some(typ) = &typ {
1524 let self_typ = self.typ();
1525 soft_assert_no_log!(
1526 self_typ
1527 .column_types
1528 .iter()
1529 .zip_eq(typ.column_types.iter())
1530 .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1531 );
1532 }
1533 let mut typ = typ.unwrap_or_else(|| self.typ());
1534 typ.keys = vec![vec![]];
1535 for ct in typ.column_types.iter_mut() {
1536 ct.nullable = false;
1537 }
1538 std::mem::replace(
1539 self,
1540 MirRelationExpr::Constant {
1541 rows: Ok(vec![]),
1542 typ,
1543 },
1544 )
1545 }
1546
1547 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1548 /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1549 /// possible nullability, since we are making an empty collection.
1550 pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1551 self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1552 }
1553
1554 /// Like [`Self::take_safely_with_col_types`], but accepts `Vec<ReprColumnType>`.
1555 ///
1556 /// This is the preferred entry point for optimizer transforms, where repr
1557 /// types are the native currency. Internally converts to [`SqlColumnType`]
1558 /// and delegates to [`Self::take_safely_with_col_types`].
1559 pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
1560 self.take_safely(Some(ReprRelationType::new(typ)))
1561 }
1562
1563 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1564 ///
1565 /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1566 pub fn take_dangerous(&mut self) -> MirRelationExpr {
1567 let empty = MirRelationExpr::Constant {
1568 rows: Ok(vec![]),
1569 typ: ReprRelationType::new(Vec::new()),
1570 };
1571 std::mem::replace(self, empty)
1572 }
1573
1574 /// Replaces `self` with some logic applied to `self`.
1575 pub fn replace_using<F>(&mut self, logic: F)
1576 where
1577 F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1578 {
1579 let empty = MirRelationExpr::Constant {
1580 rows: Ok(vec![]),
1581 typ: ReprRelationType::new(Vec::new()),
1582 };
1583 let expr = std::mem::replace(self, empty);
1584 *self = logic(expr);
1585 }
1586
1587 /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1588 pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1589 where
1590 Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1591 {
1592 if let MirRelationExpr::Get { .. } = self {
1593 // already done
1594 body(id_gen, self)
1595 } else {
1596 let id = LocalId::new(id_gen.allocate_id());
1597 let get = MirRelationExpr::Get {
1598 id: Id::Local(id),
1599 typ: self.typ(),
1600 access_strategy: AccessStrategy::UnknownOrLocal,
1601 };
1602 let body = (body)(id_gen, get)?;
1603 Ok(MirRelationExpr::Let {
1604 id,
1605 value: Box::new(self),
1606 body: Box::new(body),
1607 })
1608 }
1609 }
1610
1611 /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1612 /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1613 pub fn anti_lookup<E>(
1614 self,
1615 id_gen: &mut IdGen,
1616 keys_and_values: MirRelationExpr,
1617 default: Vec<(Datum, ReprScalarType)>,
1618 ) -> Result<MirRelationExpr, E> {
1619 let (data, column_types): (Vec<_>, Vec<_>) = default
1620 .into_iter()
1621 .map(|(datum, scalar_type)| {
1622 (
1623 datum,
1624 ReprColumnType {
1625 scalar_type,
1626 nullable: datum.is_null(),
1627 },
1628 )
1629 })
1630 .unzip();
1631 assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1632 self.let_in(id_gen, |_id_gen, get_keys| {
1633 let get_keys_arity = get_keys.arity();
1634 Ok(MirRelationExpr::join(
1635 vec![
1636 // all the missing keys (with count 1)
1637 keys_and_values
1638 .distinct_by((0..get_keys_arity).collect())
1639 .negate()
1640 .union(get_keys.clone().distinct()),
1641 // join with keys to get the correct counts
1642 get_keys.clone(),
1643 ],
1644 (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1645 )
1646 // get rid of the extra copies of columns from keys
1647 .project((0..get_keys_arity).collect())
1648 // This join is logically equivalent to
1649 // `.map(<default_expr>)`, but using a join allows for
1650 // potential predicate pushdown and elision in the
1651 // optimizer.
1652 .product(MirRelationExpr::constant(
1653 vec![data],
1654 ReprRelationType::new(column_types),
1655 )))
1656 })
1657 }
1658
1659 /// Return:
1660 /// * every row in keys_and_values
1661 /// * every row in `self` that does not have a matching row in the first columns of
1662 /// `keys_and_values`, using `default` to fill in the remaining columns
1663 /// (This is LEFT OUTER JOIN if:
1664 /// 1) `default` is a row of null
1665 /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1666 pub fn lookup<E>(
1667 self,
1668 id_gen: &mut IdGen,
1669 keys_and_values: MirRelationExpr,
1670 default: Vec<(Datum<'static>, ReprScalarType)>,
1671 ) -> Result<MirRelationExpr, E> {
1672 keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1673 Ok(get_keys_and_values.clone().union(self.anti_lookup(
1674 id_gen,
1675 get_keys_and_values,
1676 default,
1677 )?))
1678 })
1679 }
1680
1681 /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1682 pub fn contains_temporal(&self) -> bool {
1683 let mut contains = false;
1684 self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1685 contains
1686 }
1687
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// Scalars are visited in a fixed order per operator:
    /// - `Map` scalars, `Filter` predicates and `FlatMap` arguments, in sequence.
    /// - For `Join`, every equivalence-class expression, then the key expressions
    ///   recorded in its `JoinImplementation`.
    /// - `ArrangeBy` keys.
    /// - `Reduce` group-key expressions, then each aggregate's input expression.
    /// - The optional `TopK` limit.
    ///
    /// Operators without scalars are a no-op.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`; the remaining scalars are not visited.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // The planned implementation holds its own copies of key expressions,
                // which must be visited too so they stay consistent with the join.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1787
1788 /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1789 /// rooted at `self`.
1790 ///
1791 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1792 /// nodes.
1793 pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1794 where
1795 F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1796 E: From<RecursionLimitError>,
1797 {
1798 self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1799 }
1800
1801 /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1802 /// rooted at `self`.
1803 ///
1804 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1805 /// nodes.
1806 pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1807 where
1808 F: FnMut(&mut MirScalarExpr),
1809 {
1810 self.try_visit_scalars_mut(&mut |s| {
1811 f(s);
1812 Ok::<_, RecursionLimitError>(())
1813 })
1814 .expect("Unexpected error in `visit_scalars_mut` call");
1815 }
1816
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// This is the immutable counterpart of `try_visit_scalars_mut1` and visits
    /// scalars in the same fixed order per operator:
    /// - `Map` scalars, `Filter` predicates and `FlatMap` arguments, in sequence.
    /// - For `Join`, every equivalence-class expression, then the key expressions
    ///   recorded in its `JoinImplementation`.
    /// - `ArrangeBy` keys.
    /// - `Reduce` group-key expressions, then each aggregate's input expression.
    /// - The optional `TopK` limit.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`; the remaining scalars are not visited.
    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // Key expressions held by the planned implementation are visited too.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1916
1917 /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1918 /// rooted at `self`.
1919 ///
1920 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1921 /// nodes.
1922 pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1923 where
1924 F: FnMut(&MirScalarExpr) -> Result<(), E>,
1925 E: From<RecursionLimitError>,
1926 {
1927 self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1928 }
1929
1930 /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1931 /// rooted at `self`.
1932 ///
1933 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1934 /// nodes.
1935 pub fn visit_scalars<F>(&self, f: &mut F)
1936 where
1937 F: FnMut(&MirScalarExpr),
1938 {
1939 self.try_visit_scalars(&mut |s| {
1940 f(s);
1941 Ok::<_, RecursionLimitError>(())
1942 })
1943 .expect("Unexpected error in `visit_scalars` call");
1944 }
1945
1946 /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1947 /// stack overflow in `drop_in_place`.
1948 ///
1949 /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1950 /// dropped or otherwise overwritten.
1951 pub fn destroy_carefully(&mut self) {
1952 let mut todo = vec![self.take_dangerous()];
1953 while let Some(mut expr) = todo.pop() {
1954 for child in expr.children_mut() {
1955 todo.push(child.take_dangerous());
1956 }
1957 }
1958 }
1959
1960 /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1961 /// debug printing purposes.
1962 pub fn debug_size_and_depth(&self) -> (usize, usize) {
1963 let mut size = 0;
1964 let mut max_depth = 0;
1965 let mut todo = vec![(self, 1)];
1966 while let Some((expr, depth)) = todo.pop() {
1967 size += 1;
1968 max_depth = max(max_depth, depth);
1969 todo.extend(expr.children().map(|c| (c, depth + 1)));
1970 }
1971 (size, max_depth)
1972 }
1973
1974 /// The MirRelationExpr is considered potentially expensive if and only if
1975 /// at least one of the following conditions is true:
1976 ///
1977 /// - It contains at least one MirScalarExpr with a function call.
1978 /// - It contains at least one FlatMap or a Reduce operator.
1979 /// - We run into a RecursionLimitError while analyzing the expression.
1980 ///
1981 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1982 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1983 pub fn could_run_expensive_function(&self) -> bool {
1984 let mut result = false;
1985 use MirRelationExpr::*;
1986 use MirScalarExpr::*;
1987 if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1988 result |= match scalar {
1989 Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1990 // Function calls are considered expensive
1991 CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1992 };
1993 Ok(())
1994 }) {
1995 // Conservatively set `true` if on RecursionLimitError.
1996 result = true;
1997 }
1998 self.visit_pre(|e: &MirRelationExpr| {
1999 // FlatMap has a table function; Reduce has an aggregate function.
2000 // Other constructs use MirScalarExpr to run a function
2001 result |= matches!(e, FlatMap { .. } | Reduce { .. });
2002 });
2003 result
2004 }
2005
2006 /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2007 /// than what `Hashable::hashed` would give us.)
2008 pub fn hash_to_u64(&self) -> u64 {
2009 let mut h = DefaultHasher::new();
2010 self.hash(&mut h);
2011 h.finish()
2012 }
2013}
2014
2015// `LetRec` helpers
2016impl MirRelationExpr {
2017 /// True when `expr` contains a `LetRec` AST node.
2018 pub fn is_recursive(self: &MirRelationExpr) -> bool {
2019 let mut worklist = vec![self];
2020 while let Some(expr) = worklist.pop() {
2021 if let MirRelationExpr::LetRec { .. } = expr {
2022 return true;
2023 }
2024 worklist.extend(expr.children());
2025 }
2026 false
2027 }
2028
2029 /// Return the number of sub-expressions in the tree (including self).
2030 pub fn size(&self) -> usize {
2031 let mut size = 0;
2032 self.visit_pre(|_| size += 1);
2033 size
2034 }
2035
2036 /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2037 /// iterations. These are those ids that have a reference before they are defined, when reading
2038 /// all the bindings in order.
2039 ///
2040 /// For example:
2041 /// ```SQL
2042 /// WITH MUTUALLY RECURSIVE
2043 /// x(...) AS f(z),
2044 /// y(...) AS g(x),
2045 /// z(...) AS h(y)
2046 /// ...;
2047 /// ```
2048 /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2049 /// iteration.
2050 ///
2051 /// Note that if a binding references itself, that is also returned.
2052 pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2053 let mut used_across_iterations = BTreeSet::new();
2054 let mut defined = BTreeSet::new();
2055 for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2056 value.visit_pre(|expr| {
2057 if let MirRelationExpr::Get {
2058 id: Local(get_id), ..
2059 } = expr
2060 {
2061 // If we haven't seen a definition for it yet, then this will refer
2062 // to the previous iteration.
2063 // The `ids.contains` part of the condition is needed to exclude
2064 // those ids that are not really in this LetRec, but either an inner
2065 // or outer one.
2066 if !defined.contains(get_id) && ids.contains(get_id) {
2067 used_across_iterations.insert(*get_id);
2068 }
2069 }
2070 });
2071 defined.insert(*binding_id);
2072 }
2073 used_across_iterations
2074 }
2075
2076 /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2077 ///
2078 /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2079 /// identifiers are rewritten to be the constant collection.
2080 /// This makes the computation perform exactly "one" iteration.
2081 ///
2082 /// This was used only temporarily while developing `LetRec`.
2083 pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2084 let mut deadlist = BTreeSet::new();
2085 let mut worklist = vec![self];
2086 while let Some(expr) = worklist.pop() {
2087 if let MirRelationExpr::LetRec {
2088 ids,
2089 values,
2090 limits: _,
2091 body,
2092 } = expr
2093 {
2094 let ids_values = values
2095 .drain(..)
2096 .zip_eq(ids)
2097 .map(|(value, id)| (*id, value))
2098 .collect::<Vec<_>>();
2099 *expr = body.take_dangerous();
2100 for (id, mut value) in ids_values.into_iter().rev() {
2101 // Remove references to potentially recursive identifiers.
2102 deadlist.insert(id);
2103 value.visit_pre_mut(|e| {
2104 if let MirRelationExpr::Get {
2105 id: crate::Id::Local(id),
2106 typ,
2107 ..
2108 } = e
2109 {
2110 let typ = typ.clone();
2111 if deadlist.contains(id) {
2112 e.take_safely(Some(typ));
2113 }
2114 }
2115 });
2116 *expr = MirRelationExpr::Let {
2117 id,
2118 value: Box::new(value),
2119 body: Box::new(expr.take_dangerous()),
2120 };
2121 }
2122 worklist.push(expr);
2123 } else {
2124 worklist.extend(expr.children_mut().rev());
2125 }
2126 }
2127 }
2128
2129 /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2130 /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2131 /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2132 /// redefinition.
2133 ///
2134 /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2135 pub fn collect_expirations(
2136 id: LocalId,
2137 expr: &MirRelationExpr,
2138 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2139 ) {
2140 expr.visit_pre(|e| {
2141 if let MirRelationExpr::Get {
2142 id: Id::Local(referenced_id),
2143 ..
2144 } = e
2145 {
2146 // The following check needs `renumber_bindings` to have run recently
2147 if referenced_id >= &id {
2148 expire_whens
2149 .entry(*referenced_id)
2150 .or_insert_with(Vec::new)
2151 .push(id);
2152 }
2153 }
2154 });
2155 }
2156
2157 /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2158 /// about such Ids whose information depended on the earlier definition of `id`, according to
2159 /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2160 pub fn do_expirations<I>(
2161 redefined_id: LocalId,
2162 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2163 id_infos: &mut BTreeMap<LocalId, I>,
2164 ) -> Vec<(LocalId, I)> {
2165 let mut expired_infos = Vec::new();
2166 if let Some(expirations) = expire_whens.remove(&redefined_id) {
2167 for expired_id in expirations.into_iter() {
2168 if let Some(offer) = id_infos.remove(&expired_id) {
2169 expired_infos.push((expired_id, offer));
2170 }
2171 }
2172 }
2173 expired_infos
2174 }
2175}
2176/// Augment non-nullability of columns, by observing either
2177/// 1. Predicates that explicitly test for null values, and
2178/// 2. Columns that if null would make a predicate be null.
2179pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2180 let mut nonnull_required_columns = BTreeSet::new();
2181 for predicate in predicates {
2182 // Add any columns that being null would force the predicate to be null.
2183 // Should that happen, the row would be discarded.
2184 predicate.non_null_requirements(&mut nonnull_required_columns);
2185
2186 /*
2187 Test for explicit checks that a column is non-null.
2188
2189 This analysis is ad hoc, and will miss things:
2190
2191 materialize=> create table a(x int, y int);
2192 CREATE TABLE
2193 materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2194 Optimized Plan
2195 --------------------------------------------------------------------------------------------------------
2196 Explained Query: +
2197 Project (#0) // { types: "(integer?)" } +
2198 Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2199 Get materialize.public.a // { types: "(integer?, integer?)" } +
2200 +
2201 Source materialize.public.a +
2202 filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1)))) +
2203
2204 (1 row)
2205 */
2206
2207 if let MirScalarExpr::CallUnary {
2208 func: UnaryFunc::Not(scalar_func::Not),
2209 expr,
2210 } = predicate
2211 {
2212 if let MirScalarExpr::CallUnary {
2213 func: UnaryFunc::IsNull(scalar_func::IsNull),
2214 expr,
2215 } = &**expr
2216 {
2217 if let MirScalarExpr::Column(c, _name) = &**expr {
2218 nonnull_required_columns.insert(*c);
2219 }
2220 }
2221 }
2222 }
2223
2224 nonnull_required_columns
2225}
2226
2227impl CollectionPlan for MirRelationExpr {
2228 /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2229 /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2230 ///
2231 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2232 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2233 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2234 if let MirRelationExpr::Get {
2235 id: Id::Global(id), ..
2236 } = self
2237 {
2238 out.insert(*id);
2239 }
2240 self.visit_children(|expr| expr.depends_on_into(out))
2241 }
2242}
2243
2244impl MirRelationExpr {
2245 /// Iterates through references to child expressions.
2246 pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2247 let mut first = None;
2248 let mut second = None;
2249 let mut rest = None;
2250 let mut last = None;
2251
2252 use MirRelationExpr::*;
2253 match self {
2254 Constant { .. } | Get { .. } => (),
2255 Let { value, body, .. } => {
2256 first = Some(&**value);
2257 second = Some(&**body);
2258 }
2259 LetRec { values, body, .. } => {
2260 rest = Some(values);
2261 last = Some(&**body);
2262 }
2263 Project { input, .. }
2264 | Map { input, .. }
2265 | FlatMap { input, .. }
2266 | Filter { input, .. }
2267 | Reduce { input, .. }
2268 | TopK { input, .. }
2269 | Negate { input }
2270 | Threshold { input }
2271 | ArrangeBy { input, .. } => {
2272 first = Some(&**input);
2273 }
2274 Join { inputs, .. } => {
2275 rest = Some(inputs);
2276 }
2277 Union { base, inputs } => {
2278 first = Some(&**base);
2279 rest = Some(inputs);
2280 }
2281 }
2282
2283 first
2284 .into_iter()
2285 .chain(second)
2286 .chain(rest.into_iter().flatten())
2287 .chain(last)
2288 }
2289
2290 /// Iterates through mutable references to child expressions.
2291 pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2292 let mut first = None;
2293 let mut second = None;
2294 let mut rest = None;
2295 let mut last = None;
2296
2297 use MirRelationExpr::*;
2298 match self {
2299 Constant { .. } | Get { .. } => (),
2300 Let { value, body, .. } => {
2301 first = Some(&mut **value);
2302 second = Some(&mut **body);
2303 }
2304 LetRec { values, body, .. } => {
2305 rest = Some(values);
2306 last = Some(&mut **body);
2307 }
2308 Project { input, .. }
2309 | Map { input, .. }
2310 | FlatMap { input, .. }
2311 | Filter { input, .. }
2312 | Reduce { input, .. }
2313 | TopK { input, .. }
2314 | Negate { input }
2315 | Threshold { input }
2316 | ArrangeBy { input, .. } => {
2317 first = Some(&mut **input);
2318 }
2319 Join { inputs, .. } => {
2320 rest = Some(inputs);
2321 }
2322 Union { base, inputs } => {
2323 first = Some(&mut **base);
2324 rest = Some(inputs);
2325 }
2326 }
2327
2328 first
2329 .into_iter()
2330 .chain(second)
2331 .chain(rest.into_iter().flatten())
2332 .chain(last)
2333 }
2334
2335 /// Iterative pre-order visitor.
2336 pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2337 let mut worklist = vec![self];
2338 while let Some(expr) = worklist.pop() {
2339 f(expr);
2340 worklist.extend(expr.children().rev());
2341 }
2342 }
2343
2344 /// Iterative pre-order visitor.
2345 pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2346 let mut worklist = vec![self];
2347 while let Some(expr) = worklist.pop() {
2348 f(expr);
2349 worklist.extend(expr.children_mut().rev());
2350 }
2351 }
2352
2353 /// Return a vector of references to the subtrees of this expression
2354 /// in post-visit order (the last element is `&self`).
2355 pub fn post_order_vec(&self) -> Vec<&Self> {
2356 let mut stack = vec![self];
2357 let mut result = vec![];
2358 while let Some(expr) = stack.pop() {
2359 result.push(expr);
2360 stack.extend(expr.children());
2361 }
2362 result.reverse();
2363 result
2364 }
2365}
2366
2367impl VisitChildren<Self> for MirRelationExpr {
2368 fn visit_children<F>(&self, mut f: F)
2369 where
2370 F: FnMut(&Self),
2371 {
2372 for child in self.children() {
2373 f(child)
2374 }
2375 }
2376
2377 fn visit_mut_children<F>(&mut self, mut f: F)
2378 where
2379 F: FnMut(&mut Self),
2380 {
2381 for child in self.children_mut() {
2382 f(child)
2383 }
2384 }
2385
2386 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2387 where
2388 F: FnMut(&Self) -> Result<(), E>,
2389 E: From<RecursionLimitError>,
2390 {
2391 for child in self.children() {
2392 f(child)?
2393 }
2394 Ok(())
2395 }
2396
2397 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2398 where
2399 F: FnMut(&mut Self) -> Result<(), E>,
2400 E: From<RecursionLimitError>,
2401 {
2402 for child in self.children_mut() {
2403 f(child)?
2404 }
2405 Ok(())
2406 }
2407}
2408
2409/// Specification for an ordering by a column.
2410#[derive(
2411 Debug,
2412 Clone,
2413 Copy,
2414 Eq,
2415 PartialEq,
2416 Ord,
2417 PartialOrd,
2418 Serialize,
2419 Deserialize,
2420 Hash,
2421 MzReflect
2422)]
2423pub struct ColumnOrder {
2424 /// The column index.
2425 pub column: usize,
2426 /// Whether to sort in descending order.
2427 #[serde(default)]
2428 pub desc: bool,
2429 /// Whether to sort nulls last.
2430 #[serde(default)]
2431 pub nulls_last: bool,
2432}
2433
2434impl Columnation for ColumnOrder {
2435 type InnerRegion = CopyRegion<Self>;
2436}
2437
2438impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2439where
2440 M: HumanizerMode,
2441{
2442 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2443 // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2444 write!(
2445 f,
2446 "{} {} {}",
2447 self.child(&self.expr.column),
2448 if self.expr.desc { "desc" } else { "asc" },
2449 if self.expr.nulls_last {
2450 "nulls_last"
2451 } else {
2452 "nulls_first"
2453 },
2454 )
2455 }
2456}
2457
2458/// Describes an aggregation expression.
2459#[derive(
2460 Clone,
2461 Debug,
2462 Eq,
2463 PartialEq,
2464 Ord,
2465 PartialOrd,
2466 Serialize,
2467 Deserialize,
2468 Hash,
2469 MzReflect
2470)]
2471pub struct AggregateExpr {
2472 /// Names the aggregation function.
2473 pub func: AggregateFunc,
2474 /// An expression which extracts from each row the input to `func`.
2475 pub expr: MirScalarExpr,
2476 /// Should the aggregation be applied only to distinct results in each group.
2477 #[serde(default)]
2478 pub distinct: bool,
2479}
2480
2481impl AggregateExpr {
2482 /// Computes the type of this `AggregateExpr`.
2483 pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2484 self.func.output_sql_type(self.expr.sql_typ(column_types))
2485 }
2486
2487 /// Computes the type of this `AggregateExpr`.
2488 pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2489 self.func.output_type(self.expr.typ(column_types))
2490 }
2491
2492 /// Returns whether the expression has a constant result.
2493 pub fn is_constant(&self) -> bool {
2494 match self.func {
2495 AggregateFunc::MaxNumeric
2496 | AggregateFunc::MaxInt16
2497 | AggregateFunc::MaxInt32
2498 | AggregateFunc::MaxInt64
2499 | AggregateFunc::MaxUInt16
2500 | AggregateFunc::MaxUInt32
2501 | AggregateFunc::MaxUInt64
2502 | AggregateFunc::MaxMzTimestamp
2503 | AggregateFunc::MaxFloat32
2504 | AggregateFunc::MaxFloat64
2505 | AggregateFunc::MaxBool
2506 | AggregateFunc::MaxString
2507 | AggregateFunc::MaxDate
2508 | AggregateFunc::MaxTimestamp
2509 | AggregateFunc::MaxTimestampTz
2510 | AggregateFunc::MaxInterval
2511 | AggregateFunc::MaxTime
2512 | AggregateFunc::MinNumeric
2513 | AggregateFunc::MinInt16
2514 | AggregateFunc::MinInt32
2515 | AggregateFunc::MinInt64
2516 | AggregateFunc::MinUInt16
2517 | AggregateFunc::MinUInt32
2518 | AggregateFunc::MinUInt64
2519 | AggregateFunc::MinMzTimestamp
2520 | AggregateFunc::MinFloat32
2521 | AggregateFunc::MinFloat64
2522 | AggregateFunc::MinBool
2523 | AggregateFunc::MinString
2524 | AggregateFunc::MinDate
2525 | AggregateFunc::MinTimestamp
2526 | AggregateFunc::MinTimestampTz
2527 | AggregateFunc::MinInterval
2528 | AggregateFunc::MinTime
2529 | AggregateFunc::Any
2530 | AggregateFunc::All
2531 | AggregateFunc::Dummy => self.expr.is_literal(),
2532 AggregateFunc::Count => self.expr.is_literal_null(),
2533 AggregateFunc::SumInt16
2534 | AggregateFunc::SumInt32
2535 | AggregateFunc::SumInt64
2536 | AggregateFunc::SumUInt16
2537 | AggregateFunc::SumUInt32
2538 | AggregateFunc::SumUInt64
2539 | AggregateFunc::SumFloat32
2540 | AggregateFunc::SumFloat64
2541 | AggregateFunc::SumNumeric
2542 | AggregateFunc::JsonbAgg { .. }
2543 | AggregateFunc::JsonbObjectAgg { .. }
2544 | AggregateFunc::MapAgg { .. }
2545 | AggregateFunc::ArrayConcat { .. }
2546 | AggregateFunc::ListConcat { .. }
2547 | AggregateFunc::StringAgg { .. }
2548 | AggregateFunc::RowNumber { .. }
2549 | AggregateFunc::Rank { .. }
2550 | AggregateFunc::DenseRank { .. }
2551 | AggregateFunc::LagLead { .. }
2552 | AggregateFunc::FirstValue { .. }
2553 | AggregateFunc::LastValue { .. }
2554 | AggregateFunc::FusedValueWindowFunc { .. }
2555 | AggregateFunc::WindowAggregate { .. }
2556 | AggregateFunc::FusedWindowAggregate { .. } => self.expr.is_literal_err(),
2557 }
2558 }
2559
2560 /// Returns an expression that computes `self` on a group that has exactly one row.
2561 /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2562 /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2563 pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
2564 match &self.func {
2565 // Count is one if non-null, and zero if null.
2566 AggregateFunc::Count => self
2567 .expr
2568 .clone()
2569 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2570 .if_then_else(
2571 MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
2572 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2573 ),
2574
2575 // SumInt16 takes Int16s as input, but outputs Int64s.
2576 AggregateFunc::SumInt16 => self
2577 .expr
2578 .clone()
2579 .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2580
2581 // SumInt32 takes Int32s as input, but outputs Int64s.
2582 AggregateFunc::SumInt32 => self
2583 .expr
2584 .clone()
2585 .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2586
2587 // SumInt64 takes Int64s as input, but outputs numerics.
2588 AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2589 scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2590 )),
2591
2592 // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2593 AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2594 UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2595 ),
2596
2597 // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2598 AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2599 UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2600 ),
2601
2602 // SumUInt64 takes UInt64s as input, but outputs numerics.
2603 AggregateFunc::SumUInt64 => {
2604 self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2605 scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2606 ))
2607 }
2608
2609 // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2610 AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
2611 JsonbBuildArray,
2612 vec![
2613 self.expr
2614 .clone()
2615 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2616 ],
2617 ),
2618
2619 // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2620 AggregateFunc::JsonbObjectAgg { .. } => {
2621 let record = self
2622 .expr
2623 .clone()
2624 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2625 MirScalarExpr::call_variadic(
2626 JsonbBuildObject,
2627 (0..2)
2628 .map(|i| {
2629 record
2630 .clone()
2631 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2632 })
2633 .collect(),
2634 )
2635 }
2636
2637 AggregateFunc::MapAgg { value_type, .. } => {
2638 let record = self
2639 .expr
2640 .clone()
2641 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2642 MirScalarExpr::call_variadic(
2643 MapBuild {
2644 value_type: value_type.clone(),
2645 },
2646 (0..2)
2647 .map(|i| {
2648 record
2649 .clone()
2650 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2651 })
2652 .collect(),
2653 )
2654 }
2655
2656 // StringAgg takes nested records of strings and outputs a string
2657 AggregateFunc::StringAgg { .. } => self
2658 .expr
2659 .clone()
2660 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2661 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2662
2663 // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2664 AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2665 .expr
2666 .clone()
2667 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2668
2669 // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2670 AggregateFunc::RowNumber { .. } => {
2671 self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2672 }
2673 AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2674 AggregateFunc::DenseRank { .. } => {
2675 self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2676 }
2677
2678 // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2679 AggregateFunc::LagLead { lag_lead, .. } => {
2680 let tuple = self
2681 .expr
2682 .clone()
2683 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2684
2685 // Get the overall return type
2686 let return_type_with_orig_row = self
2687 .typ(input_type)
2688 .scalar_type
2689 .unwrap_list_element_type()
2690 .clone();
2691 let lag_lead_return_type =
2692 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2693
2694 // Extract the original row
2695 let original_row = tuple
2696 .clone()
2697 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2698
2699 // Extract the encoded args
2700 let encoded_args =
2701 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2702
2703 let (result_expr, column_name) =
2704 Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2705
2706 MirScalarExpr::call_variadic(
2707 ListCreate {
2708 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2709 },
2710 vec![MirScalarExpr::call_variadic(
2711 RecordCreate {
2712 field_names: vec![column_name, ColumnName::from("?record?")],
2713 },
2714 vec![result_expr, original_row],
2715 )],
2716 )
2717 }
2718
2719 // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2720 AggregateFunc::FirstValue { window_frame, .. } => {
2721 let tuple = self
2722 .expr
2723 .clone()
2724 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2725
2726 // Get the overall return type
2727 let return_type_with_orig_row = self
2728 .typ(input_type)
2729 .scalar_type
2730 .unwrap_list_element_type()
2731 .clone();
2732 let first_value_return_type =
2733 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2734
2735 // Extract the original row
2736 let original_row = tuple
2737 .clone()
2738 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2739
2740 // Extract the input value
2741 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2742
2743 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2744 window_frame,
2745 arg,
2746 first_value_return_type,
2747 );
2748
2749 MirScalarExpr::call_variadic(
2750 ListCreate {
2751 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2752 },
2753 vec![MirScalarExpr::call_variadic(
2754 RecordCreate {
2755 field_names: vec![column_name, ColumnName::from("?record?")],
2756 },
2757 vec![result_expr, original_row],
2758 )],
2759 )
2760 }
2761
2762 // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2763 AggregateFunc::LastValue { window_frame, .. } => {
2764 let tuple = self
2765 .expr
2766 .clone()
2767 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2768
2769 // Get the overall return type
2770 let return_type_with_orig_row = self
2771 .typ(input_type)
2772 .scalar_type
2773 .unwrap_list_element_type()
2774 .clone();
2775 let last_value_return_type =
2776 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2777
2778 // Extract the original row
2779 let original_row = tuple
2780 .clone()
2781 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2782
2783 // Extract the input value
2784 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2785
2786 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2787 window_frame,
2788 arg,
2789 last_value_return_type,
2790 );
2791
2792 MirScalarExpr::call_variadic(
2793 ListCreate {
2794 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2795 },
2796 vec![MirScalarExpr::call_variadic(
2797 RecordCreate {
2798 field_names: vec![column_name, ColumnName::from("?record?")],
2799 },
2800 vec![result_expr, original_row],
2801 )],
2802 )
2803 }
2804
2805 // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2806 // See an example MIR in `window_func_applied_to`.
2807 AggregateFunc::WindowAggregate {
2808 wrapped_aggregate,
2809 window_frame,
2810 order_by: _,
2811 } => {
2812 // TODO: deduplicate code between the various window function cases.
2813
2814 let tuple = self
2815 .expr
2816 .clone()
2817 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2818
2819 // Get the overall return type
2820 let return_type = self
2821 .typ(input_type)
2822 .scalar_type
2823 .unwrap_list_element_type()
2824 .clone();
2825 let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2826
2827 // Extract the original row
2828 let original_row = tuple
2829 .clone()
2830 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2831
2832 // Extract the input value
2833 let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2834
2835 let (result, column_name) = Self::on_unique_window_agg(
2836 window_frame,
2837 arg_expr,
2838 input_type,
2839 window_agg_return_type,
2840 wrapped_aggregate,
2841 );
2842
2843 MirScalarExpr::call_variadic(
2844 ListCreate {
2845 elem_type: SqlScalarType::from_repr(&return_type),
2846 },
2847 vec![MirScalarExpr::call_variadic(
2848 RecordCreate {
2849 field_names: vec![column_name, ColumnName::from("?record?")],
2850 },
2851 vec![result, original_row],
2852 )],
2853 )
2854 }
2855
2856 // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2857 AggregateFunc::FusedWindowAggregate {
2858 wrapped_aggregates,
2859 order_by: _,
2860 window_frame,
2861 } => {
2862 // Throw away OrderByExprs
2863 let tuple = self
2864 .expr
2865 .clone()
2866 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2867
2868 // Extract the original row
2869 let original_row = tuple
2870 .clone()
2871 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2872
2873 // Extract the args of the fused call
2874 let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2875
2876 let return_type_with_orig_row = self
2877 .typ(input_type)
2878 .scalar_type
2879 .unwrap_list_element_type()
2880 .clone();
2881
2882 let all_func_return_types =
2883 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2884 let mut func_result_exprs = Vec::new();
2885 let mut col_names = Vec::new();
2886 for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2887 let arg = all_args
2888 .clone()
2889 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2890 let return_type =
2891 all_func_return_types.unwrap_record_element_type()[idx].clone();
2892 let (result, column_name) = Self::on_unique_window_agg(
2893 window_frame,
2894 arg,
2895 input_type,
2896 return_type,
2897 wrapped_aggr,
2898 );
2899 func_result_exprs.push(result);
2900 col_names.push(column_name);
2901 }
2902
2903 MirScalarExpr::call_variadic(
2904 ListCreate {
2905 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2906 },
2907 vec![MirScalarExpr::call_variadic(
2908 RecordCreate {
2909 field_names: vec![
2910 ColumnName::from("?fused_window_aggr?"),
2911 ColumnName::from("?record?"),
2912 ],
2913 },
2914 vec![
2915 MirScalarExpr::call_variadic(
2916 RecordCreate {
2917 field_names: col_names,
2918 },
2919 func_result_exprs,
2920 ),
2921 original_row,
2922 ],
2923 )],
2924 )
2925 }
2926
2927 // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
2928 AggregateFunc::FusedValueWindowFunc {
2929 funcs,
2930 order_by: outer_order_by,
2931 } => {
2932 // Throw away OrderByExprs
2933 let tuple = self
2934 .expr
2935 .clone()
2936 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2937
2938 // Extract the original row
2939 let original_row = tuple
2940 .clone()
2941 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2942
2943 // Extract the encoded args of the fused call
2944 let all_encoded_args =
2945 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2946
2947 let return_type_with_orig_row = self
2948 .typ(input_type)
2949 .scalar_type
2950 .unwrap_list_element_type()
2951 .clone();
2952
2953 let all_func_return_types =
2954 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2955 let mut func_result_exprs = Vec::new();
2956 let mut col_names = Vec::new();
2957 for (idx, func) in funcs.iter().enumerate() {
2958 let args_for_func = all_encoded_args
2959 .clone()
2960 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2961 let return_type_for_func =
2962 all_func_return_types.unwrap_record_element_type()[idx].clone();
2963 let (result, column_name) = match func {
2964 AggregateFunc::LagLead {
2965 lag_lead,
2966 order_by,
2967 ignore_nulls: _,
2968 } => {
2969 assert_eq!(order_by, outer_order_by);
2970 Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
2971 }
2972 AggregateFunc::FirstValue {
2973 window_frame,
2974 order_by,
2975 } => {
2976 assert_eq!(order_by, outer_order_by);
2977 Self::on_unique_first_value_last_value(
2978 window_frame,
2979 args_for_func,
2980 return_type_for_func,
2981 )
2982 }
2983 AggregateFunc::LastValue {
2984 window_frame,
2985 order_by,
2986 } => {
2987 assert_eq!(order_by, outer_order_by);
2988 Self::on_unique_first_value_last_value(
2989 window_frame,
2990 args_for_func,
2991 return_type_for_func,
2992 )
2993 }
2994 _ => panic!("unknown function in FusedValueWindowFunc"),
2995 };
2996 func_result_exprs.push(result);
2997 col_names.push(column_name);
2998 }
2999
3000 MirScalarExpr::call_variadic(
3001 ListCreate {
3002 elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
3003 },
3004 vec![MirScalarExpr::call_variadic(
3005 RecordCreate {
3006 field_names: vec![
3007 ColumnName::from("?fused_value_window_func?"),
3008 ColumnName::from("?record?"),
3009 ],
3010 },
3011 vec![
3012 MirScalarExpr::call_variadic(
3013 RecordCreate {
3014 field_names: col_names,
3015 },
3016 func_result_exprs,
3017 ),
3018 original_row,
3019 ],
3020 )],
3021 )
3022 }
3023
3024 // All other variants should return the argument to the aggregation.
3025 AggregateFunc::MaxNumeric
3026 | AggregateFunc::MaxInt16
3027 | AggregateFunc::MaxInt32
3028 | AggregateFunc::MaxInt64
3029 | AggregateFunc::MaxUInt16
3030 | AggregateFunc::MaxUInt32
3031 | AggregateFunc::MaxUInt64
3032 | AggregateFunc::MaxMzTimestamp
3033 | AggregateFunc::MaxFloat32
3034 | AggregateFunc::MaxFloat64
3035 | AggregateFunc::MaxBool
3036 | AggregateFunc::MaxString
3037 | AggregateFunc::MaxDate
3038 | AggregateFunc::MaxTimestamp
3039 | AggregateFunc::MaxTimestampTz
3040 | AggregateFunc::MaxInterval
3041 | AggregateFunc::MaxTime
3042 | AggregateFunc::MinNumeric
3043 | AggregateFunc::MinInt16
3044 | AggregateFunc::MinInt32
3045 | AggregateFunc::MinInt64
3046 | AggregateFunc::MinUInt16
3047 | AggregateFunc::MinUInt32
3048 | AggregateFunc::MinUInt64
3049 | AggregateFunc::MinMzTimestamp
3050 | AggregateFunc::MinFloat32
3051 | AggregateFunc::MinFloat64
3052 | AggregateFunc::MinBool
3053 | AggregateFunc::MinString
3054 | AggregateFunc::MinDate
3055 | AggregateFunc::MinTimestamp
3056 | AggregateFunc::MinTimestampTz
3057 | AggregateFunc::MinInterval
3058 | AggregateFunc::MinTime
3059 | AggregateFunc::SumFloat32
3060 | AggregateFunc::SumFloat64
3061 | AggregateFunc::SumNumeric
3062 | AggregateFunc::Any
3063 | AggregateFunc::All
3064 | AggregateFunc::Dummy => self.expr.clone(),
3065 }
3066 }
3067
3068 /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3069 fn on_unique_ranking_window_funcs(
3070 &self,
3071 input_type: &[ReprColumnType],
3072 col_name: &str,
3073 ) -> MirScalarExpr {
3074 let sql_input_type: Vec<SqlColumnType> =
3075 input_type.iter().map(SqlColumnType::from_repr).collect();
3076 let list = self
3077 .expr
3078 .clone()
3079 // extract the list within the record
3080 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3081
3082 // extract the expression within the list
3083 let record = MirScalarExpr::call_variadic(
3084 ListIndex,
3085 vec![
3086 list,
3087 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3088 ],
3089 );
3090
3091 MirScalarExpr::call_variadic(
3092 ListCreate {
3093 elem_type: self
3094 .sql_typ(&sql_input_type)
3095 .scalar_type
3096 .unwrap_list_element_type()
3097 .clone(),
3098 },
3099 vec![MirScalarExpr::call_variadic(
3100 RecordCreate {
3101 field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3102 },
3103 vec![
3104 MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3105 record,
3106 ],
3107 )],
3108 )
3109 }
3110
3111 /// `on_unique` for `lag` and `lead`
3112 fn on_unique_lag_lead(
3113 lag_lead: &LagLeadType,
3114 encoded_args: MirScalarExpr,
3115 return_type: ReprScalarType,
3116 ) -> (MirScalarExpr, ColumnName) {
3117 let expr = encoded_args
3118 .clone()
3119 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3120 let offset = encoded_args
3121 .clone()
3122 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3123 let default_value =
3124 encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3125
3126 // In this case, the window always has only one element, so if the offset is not null and
3127 // not zero, the default value should be returned instead.
3128 let value = offset
3129 .clone()
3130 .call_binary(
3131 MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
3132 crate::func::Eq,
3133 )
3134 .if_then_else(expr, default_value);
3135 let result_expr = offset
3136 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3137 .if_then_else(MirScalarExpr::literal_null(return_type), value);
3138
3139 let column_name = ColumnName::from(match lag_lead {
3140 LagLeadType::Lag => "?lag?",
3141 LagLeadType::Lead => "?lead?",
3142 });
3143
3144 (result_expr, column_name)
3145 }
3146
3147 /// `on_unique` for `first_value` and `last_value`
3148 fn on_unique_first_value_last_value(
3149 window_frame: &WindowFrame,
3150 arg: MirScalarExpr,
3151 return_type: ReprScalarType,
3152 ) -> (MirScalarExpr, ColumnName) {
3153 // If the window frame includes the current (single) row, return its value, null otherwise
3154 let result_expr = if window_frame.includes_current_row() {
3155 arg
3156 } else {
3157 MirScalarExpr::literal_null(return_type)
3158 };
3159 (result_expr, ColumnName::from("?first_value?"))
3160 }
3161
3162 /// `on_unique` for window aggregations
3163 fn on_unique_window_agg(
3164 window_frame: &WindowFrame,
3165 arg_expr: MirScalarExpr,
3166 input_type: &[ReprColumnType],
3167 return_type: ReprScalarType,
3168 wrapped_aggr: &AggregateFunc,
3169 ) -> (MirScalarExpr, ColumnName) {
3170 // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3171 // that row. Otherwise, return the default value for the aggregate.
3172 let result_expr = if window_frame.includes_current_row() {
3173 AggregateExpr {
3174 func: wrapped_aggr.clone(),
3175 expr: arg_expr,
3176 distinct: false, // We have just one input element; DISTINCT doesn't matter.
3177 }
3178 .on_unique(input_type)
3179 } else {
3180 MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3181 };
3182 (result_expr, ColumnName::from("?window_agg?"))
3183 }
3184
3185 /// Returns whether the expression is COUNT(*) or not. Note that
3186 /// when we define the count builtin in sql::func, we convert
3187 /// COUNT(*) to COUNT(true), making it indistinguishable from
3188 /// literal COUNT(true), but we prefer to consider this as the
3189 /// former.
3190 ///
3191 /// (HIR has the same `is_count_asterisk`.)
3192 pub fn is_count_asterisk(&self) -> bool {
3193 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3194 }
3195}
3196
/// Describes how a join is implemented in dataflow.
///
/// A freshly constructed join carries [`JoinImplementation::Unimplemented`]; the other
/// variants record a concrete plan.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists the other relation indexes, each with the key of
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, one per input collection, in input order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`).
    /// The `<constants>` rows are excluded from `MzReflect` via `mzreflect(ignore)`.
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation has been selected yet.
    Unimplemented,
}
3255
3256impl Default for JoinImplementation {
3257 fn default() -> Self {
3258 JoinImplementation::Unimplemented
3259 }
3260}
3261
3262impl JoinImplementation {
3263 /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3264 pub fn is_implemented(&self) -> bool {
3265 match self {
3266 Self::Unimplemented => false,
3267 _ => true,
3268 }
3269 }
3270
3271 /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3272 pub fn name(&self) -> Option<&'static str> {
3273 match self {
3274 Self::Differential(..) => Some("differential"),
3275 Self::DeltaQuery(..) => Some("delta"),
3276 Self::IndexedFilter(..) => Some("indexed_filter"),
3277 Self::Unimplemented => None,
3278 }
3279 }
3280}
3281
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// the documented fields of [`JoinInputCharacteristicsV1`] and [`JoinInputCharacteristicsV2`].
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag.
///
/// Note that the derived `Ord` compares the variant first, so any `V1` value orders before
/// any `V2` value. Presumably the two versions are never mixed within one comparison.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3311
3312impl JoinInputCharacteristics {
3313 /// Creates a new instance with the given characteristics.
3314 pub fn new(
3315 unique_key: bool,
3316 key_length: usize,
3317 arranged: bool,
3318 cardinality: Option<usize>,
3319 filters: FilterCharacteristics,
3320 input: usize,
3321 enable_join_prioritize_arranged: bool,
3322 ) -> Self {
3323 if enable_join_prioritize_arranged {
3324 Self::V2(JoinInputCharacteristicsV2::new(
3325 unique_key,
3326 key_length,
3327 arranged,
3328 cardinality,
3329 filters,
3330 input,
3331 ))
3332 } else {
3333 Self::V1(JoinInputCharacteristicsV1::new(
3334 unique_key,
3335 key_length,
3336 arranged,
3337 cardinality,
3338 filters,
3339 input,
3340 ))
3341 }
3342 }
3343
3344 /// Turns the instance into a String to be printed in EXPLAIN.
3345 pub fn explain(&self) -> String {
3346 match self {
3347 Self::V1(jic) => jic.explain(),
3348 Self::V2(jic) => jic.explain(),
3349 }
3350 }
3351
3352 /// Whether the join input described by `self` is arranged.
3353 pub fn arranged(&self) -> bool {
3354 match self {
3355 Self::V1(jic) => jic.arranged,
3356 Self::V2(jic) => jic.arranged,
3357 }
3358 }
3359
3360 /// Returns the `FilterCharacteristics` for the join input described by `self`.
3361 pub fn filters(&mut self) -> &mut FilterCharacteristics {
3362 match self {
3363 Self::V1(jic) => &mut jic.filters,
3364 Self::V2(jic) => &mut jic.filters,
3365 }
3366 }
3367}
3368
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// The derived `Ord` compares fields lexicographically in declaration order, so the order of
/// the fields below is the priority order used when comparing candidates. The fields are set
/// up so that a greater value denotes a more preferable candidate, which is why `cardinality`
/// and `input` (where lower is better) are wrapped in `Reverse`. Reordering the fields changes
/// join ordering decisions.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3402impl JoinInputCharacteristicsV2 {
3403 /// Creates a new instance with the given characteristics.
3404 pub fn new(
3405 unique_key: bool,
3406 key_length: usize,
3407 arranged: bool,
3408 cardinality: Option<usize>,
3409 filters: FilterCharacteristics,
3410 input: usize,
3411 ) -> Self {
3412 Self {
3413 unique_key,
3414 not_cross: key_length > 0,
3415 arranged,
3416 key_length,
3417 cardinality: cardinality.map(std::cmp::Reverse),
3418 filters,
3419 input: std::cmp::Reverse(input),
3420 }
3421 }
3422
3423 /// Turns the instance into a String to be printed in EXPLAIN.
3424 pub fn explain(&self) -> String {
3425 let mut e = "".to_owned();
3426 if self.unique_key {
3427 e.push_str("U");
3428 }
3429 // Don't need to print `not_cross`, because that is visible in the printed key.
3430 // if !self.not_cross {
3431 // e.push_str("C");
3432 // }
3433 for _ in 0..self.key_length {
3434 e.push_str("K");
3435 }
3436 if self.arranged {
3437 e.push_str("A");
3438 }
3439 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3440 e.push_str(&format!("|{cardinality}|"));
3441 }
3442 e.push_str(&self.filters.explain());
3443 e
3444 }
3445}
3446
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// As with the newer version, the derived `Ord` compares fields lexicographically in
/// declaration order, so the order of the fields below is the priority order used when
/// comparing candidates. Unlike V2, `key_length` takes precedence over `arranged` here.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3474
3475impl JoinInputCharacteristicsV1 {
3476 /// Creates a new instance with the given characteristics.
3477 pub fn new(
3478 unique_key: bool,
3479 key_length: usize,
3480 arranged: bool,
3481 cardinality: Option<usize>,
3482 filters: FilterCharacteristics,
3483 input: usize,
3484 ) -> Self {
3485 Self {
3486 unique_key,
3487 key_length,
3488 arranged,
3489 cardinality: cardinality.map(std::cmp::Reverse),
3490 filters,
3491 input: std::cmp::Reverse(input),
3492 }
3493 }
3494
3495 /// Turns the instance into a String to be printed in EXPLAIN.
3496 pub fn explain(&self) -> String {
3497 let mut e = "".to_owned();
3498 if self.unique_key {
3499 e.push_str("U");
3500 }
3501 for _ in 0..self.key_length {
3502 e.push_str("K");
3503 }
3504 if self.arranged {
3505 e.push_str("A");
3506 }
3507 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3508 e.push_str(&format!("|{cardinality}|"));
3509 }
3510 e.push_str(&self.filters.explain());
3511 e
3512 }
3513}
3514
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// If set, include at most this many rows, counted after `offset` has been applied.
    pub limit: Option<L>,
    /// Skip this many rows from the start of the result.
    pub offset: O,
    /// Project each row onto these column indices; `0..arity` is the identity projection.
    pub project: Vec<usize>,
}
3537
3538impl<L> RowSetFinishing<L> {
3539 /// Returns a trivial finishing, i.e., that does nothing to the result set.
3540 pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3541 RowSetFinishing {
3542 order_by: Vec::new(),
3543 limit: None,
3544 offset: 0,
3545 project: (0..arity).collect(),
3546 }
3547 }
3548 /// True if the finishing does nothing to any result set.
3549 pub fn is_trivial(&self, arity: usize) -> bool {
3550 self.limit.is_none()
3551 && self.order_by.is_empty()
3552 && self.offset == 0
3553 && self.project.iter().copied().eq(0..arity)
3554 }
3555 /// True if the finishing does not require an ORDER BY.
3556 ///
3557 /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3558 /// explicit ordering we will skip an arbitrary bag of elements and return
3559 /// the first arbitrary elements in the remaining bag. The result semantics
3560 /// are still correct but maybe surprising for some users.
3561 pub fn is_streamable(&self, arity: usize) -> bool {
3562 self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3563 }
3564}
3565
3566impl RowSetFinishing<NonNeg<i64>, usize> {
3567 /// The number of rows needed from before the finishing to evaluate the finishing:
3568 /// offset + limit.
3569 ///
3570 /// If it returns None, then we need all the rows.
3571 pub fn num_rows_needed(&self) -> Option<usize> {
3572 self.limit
3573 .as_ref()
3574 .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3575 }
3576}
3577
3578impl RowSetFinishing {
3579 /// Applies finishing actions to a [`RowCollection`], and reports the total
3580 /// time it took to run.
3581 ///
3582 /// Returns a [`RowCollectionIter`] that contains all of the response data, as
3583 /// well as the size of the response in bytes.
3584 pub fn finish(
3585 &self,
3586 rows: RowCollection,
3587 max_result_size: u64,
3588 max_returned_query_size: Option<u64>,
3589 duration_histogram: &Histogram,
3590 ) -> Result<(RowCollectionIter, usize), String> {
3591 let now = Instant::now();
3592 let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3593 let duration = now.elapsed();
3594 duration_histogram.observe(duration.as_secs_f64());
3595
3596 result
3597 }
3598
3599 /// Implementation for [`RowSetFinishing::finish`].
3600 fn finish_inner(
3601 &self,
3602 rows: RowCollection,
3603 max_result_size: u64,
3604 max_returned_query_size: Option<u64>,
3605 ) -> Result<(RowCollectionIter, usize), String> {
3606 // How much additional memory is required to make a sorted view.
3607 let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3608 let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3609
3610 // Bail if creating the sorted view would require us to use too much memory.
3611 if required_memory > usize::cast_from(max_result_size) {
3612 let max_bytes = ByteSize::b(max_result_size);
3613 return Err(format!("result exceeds max size of {max_bytes}",));
3614 }
3615
3616 let sorted_view = rows;
3617 let mut iter = sorted_view
3618 .into_row_iter()
3619 .apply_offset(self.offset)
3620 .with_projection(self.project.clone());
3621
3622 if let Some(limit) = self.limit {
3623 let limit = u64::from(limit);
3624 let limit = usize::cast_from(limit);
3625 iter = iter.with_limit(limit);
3626 };
3627
3628 // TODO(parkmycar): Re-think how we can calculate the total response size without
3629 // having to iterate through the entire collection of Rows, while still
3630 // respecting the LIMIT, OFFSET, and projections.
3631 //
3632 // Note: It feels a bit bad always calculating the response size, but we almost
3633 // always need it to either check the `max_returned_query_size`, or for reporting
3634 // in the query history.
3635 let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3636
3637 // Bail if we would end up returning more data to the client than they can support.
3638 if let Some(max) = max_returned_query_size {
3639 if response_size > usize::cast_from(max) {
3640 let max_bytes = ByteSize::b(max);
3641 return Err(format!("result exceeds max size of {max_bytes}"));
3642 }
3643 }
3644
3645 Ok((iter, response_size))
3646 }
3647}
3648
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
///
/// It has no ORDER BY, so it only applies to finishings that are
/// [streamable](RowSetFinishing::is_streamable).
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// If set, include at most this many more rows (counted after the offset).
    pub remaining_limit: Option<usize>,
    /// Number of rows still to be skipped before any rows are returned.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Project each row onto these column indices.
    pub project: Vec<usize>,
}
3665
3666impl RowSetFinishingIncremental {
3667 /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3668 /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3669 /// `true`.
3670 ///
3671 /// # Panics
3672 ///
3673 /// Panics if the result is not streamable, that is it has an ORDER BY.
3674 pub fn new(
3675 offset: usize,
3676 limit: Option<NonNeg<i64>>,
3677 project: Vec<usize>,
3678 max_returned_query_size: Option<u64>,
3679 ) -> Self {
3680 let limit = limit.map(|l| {
3681 let l = u64::from(l);
3682 let l = usize::cast_from(l);
3683 l
3684 });
3685
3686 RowSetFinishingIncremental {
3687 remaining_limit: limit,
3688 remaining_offset: offset,
3689 max_returned_query_size,
3690 remaining_max_returned_query_size: max_returned_query_size,
3691 project,
3692 }
3693 }
3694
3695 /// Applies finishing actions to the given [`RowCollection`], and reports
3696 /// the total time it took to run.
3697 ///
3698 /// Returns a [`RowCollectionIter`] that contains all of the response
3699 /// data.
3700 pub fn finish_incremental(
3701 &mut self,
3702 rows: RowCollection,
3703 max_result_size: u64,
3704 duration_histogram: &Histogram,
3705 ) -> Result<RowCollectionIter, String> {
3706 let now = Instant::now();
3707 let result = self.finish_incremental_inner(rows, max_result_size);
3708 let duration = now.elapsed();
3709 duration_histogram.observe(duration.as_secs_f64());
3710
3711 result
3712 }
3713
3714 fn finish_incremental_inner(
3715 &mut self,
3716 rows: RowCollection,
3717 max_result_size: u64,
3718 ) -> Result<RowCollectionIter, String> {
3719 // How much additional memory is required to make a sorted view.
3720 let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3721 let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3722
3723 // Bail if creating the sorted view would require us to use too much memory.
3724 if required_memory > usize::cast_from(max_result_size) {
3725 let max_bytes = ByteSize::b(max_result_size);
3726 return Err(format!("total result exceeds max size of {max_bytes}",));
3727 }
3728
3729 let batch_num_rows = rows.count();
3730
3731 let sorted_view = rows;
3732 let mut iter = sorted_view
3733 .into_row_iter()
3734 .apply_offset(self.remaining_offset)
3735 .with_projection(self.project.clone());
3736
3737 if let Some(limit) = self.remaining_limit {
3738 iter = iter.with_limit(limit);
3739 };
3740
3741 self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3742 if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3743 *remaining_limit -= iter.count();
3744 }
3745
3746 // TODO(parkmycar): Re-think how we can calculate the total response size without
3747 // having to iterate through the entire collection of Rows, while still
3748 // respecting the LIMIT, OFFSET, and projections.
3749 //
3750 // Note: It feels a bit bad always calculating the response size, but we almost
3751 // always need it to either check the `max_returned_query_size`, or for reporting
3752 // in the query history.
3753 let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3754
3755 // Bail if we would end up returning more data to the client than they can support.
3756 if let Some(max) = &mut self.remaining_max_returned_query_size {
3757 if let Some(remaining) = max.checked_sub(response_size.cast_into()) {
3758 *max = remaining;
3759 } else {
3760 let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3761 return Err(format!("total result exceeds max size of {max_bytes}"));
3762 }
3763 }
3764
3765 Ok(iter)
3766 }
3767}
3768
/// Compares two rows columnwise, using [compare_columns].
///
/// Compared to the naive implementation, this allows sharing some memory and implements some
/// optimizations that avoid unnecessary row unpacking.
///
/// The two `DatumVec` scratch buffers are reused across comparisons. They sit behind
/// `RefCell`s so that comparison can take `&self`, which also makes this type `!Sync`.
#[derive(Debug, Clone)]
pub struct RowComparator<O: AsRef<[ColumnOrder]> = Vec<ColumnOrder>> {
    order: O,
    /// Invariant: all column references in the order are less than this limit.
    /// This allows for partial unpacking of rows.
    limit: usize,
    left_vec: RefCell<DatumVec>,
    right_vec: RefCell<DatumVec>,
}
3782
3783impl<O: AsRef<[ColumnOrder]>> RowComparator<O> {
3784 /// Create a new row comparator from the given column ordering.
3785 pub fn new(order: O) -> Self {
3786 let limit = order
3787 .as_ref()
3788 .iter()
3789 .map(|o| o.column + 1)
3790 .max()
3791 .unwrap_or(0);
3792 Self {
3793 order,
3794 limit,
3795 left_vec: Default::default(),
3796 right_vec: Default::default(),
3797 }
3798 }
3799
3800 /// Compare two (references to) rows.
3801 pub fn compare_rows(
3802 &self,
3803 left_row: &RowRef,
3804 right_row: &RowRef,
3805 tiebreaker: impl Fn() -> Ordering,
3806 ) -> Ordering {
3807 let order = if self.limit == 0 {
3808 Ordering::Equal
3809 } else {
3810 // These borrows should never fail, since this struct is non-sync and this function
3811 // is non-recursive.
3812 let mut left_ref = self.left_vec.borrow_mut();
3813 let mut right_ref = self.right_vec.borrow_mut();
3814 let left_cols = left_ref.borrow_with_limit(left_row, self.limit);
3815 let right_cols = right_ref.borrow_with_limit(right_row, self.limit);
3816 compare_columns(self.order.as_ref(), &left_cols, &right_cols, || {
3817 Ordering::Equal
3818 })
3819 };
3820 // Tiebreak without the vecs borrowed, in case that recursively invokes this function.
3821 order.then_with(tiebreaker)
3822 }
3823}
3824
3825/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3826/// ordering, call `tiebreaker`.
3827pub fn compare_columns<F>(
3828 order: &[ColumnOrder],
3829 left: &[Datum],
3830 right: &[Datum],
3831 tiebreaker: F,
3832) -> Ordering
3833where
3834 F: Fn() -> Ordering,
3835{
3836 for order in order {
3837 let cmp = match (&left[order.column], &right[order.column]) {
3838 (Datum::Null, Datum::Null) => Ordering::Equal,
3839 (Datum::Null, _) => {
3840 if order.nulls_last {
3841 Ordering::Greater
3842 } else {
3843 Ordering::Less
3844 }
3845 }
3846 (_, Datum::Null) => {
3847 if order.nulls_last {
3848 Ordering::Less
3849 } else {
3850 Ordering::Greater
3851 }
3852 }
3853 (lval, rval) => {
3854 if order.desc {
3855 rval.cmp(lval)
3856 } else {
3857 lval.cmp(rval)
3858 }
3859 }
3860 };
3861 if cmp != Ordering::Equal {
3862 return cmp;
3863 }
3864 }
3865 tiebreaker()
3866}
3867
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3893
3894impl Display for WindowFrame {
3895 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3896 write!(
3897 f,
3898 "{} between {} and {}",
3899 self.units, self.start_bound, self.end_bound
3900 )
3901 }
3902}
3903
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined:
    /// `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`.
    ///
    /// NOTE(review): this is an inherent `default`, not an implementation of the `Default`
    /// trait, so generic code bounded on `Default` will not pick it up.
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Returns whether the frame's bounds enclose the current row. That requires the start
    /// bound to be at or before the current row and the end bound to be at or after it.
    /// `0 PRECEDING` and `0 FOLLOWING` count as the current row itself.
    ///
    /// The `unreachable!()` arms are bound combinations whose end lies before their start
    /// in an invalid way (e.g. ending at `UNBOUNDED PRECEDING`). These are presumably
    /// rejected during planning (TODO confirm); reaching them here panics.
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            // Starts before the current row: included iff the end is at or after it.
            UnboundedPreceding => match self.end_bound {
                UnboundedPreceding => false,
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Starts exactly at the current row.
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Starts strictly before the current row: included iff the end is at or after it.
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            CurrentRow => true,
            // `0 FOLLOWING` starts exactly at the current row.
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Starts strictly after the current row, so it can never include it.
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            UnboundedFollowing => false,
        }
    }
}
3962
/// Describes how a [WindowFrame]'s bounds are interpreted.
///
/// Note: the derived `Ord`/`PartialOrd` compare variants in declaration order, and the
/// derived serde impls key on the variant names, so variants should not be reordered or
/// renamed casually.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    /// (`ROWS`; offsets count individual rows).
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression (`RANGE`).
    Range,
    /// Each peer group is treated as the unit of work for bounds (`GROUPS`).
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3986
3987impl Display for WindowFrameUnits {
3988 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3989 match self {
3990 WindowFrameUnits::Rows => write!(f, "rows"),
3991 WindowFrameUnits::Range => write!(f, "range"),
3992 WindowFrameUnits::Groups => write!(f, "groups"),
3993 }
3994 }
3995}
3996
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there. The variants are declared in the same order in which the
/// bounds appear relative to the current row (from `UNBOUNDED PRECEDING` to
/// `UNBOUNDED FOLLOWING`), and the derived `Ord` compares variants in that declaration
/// order first. Within a single offset variant, however, the derived `Ord` compares the
/// raw offset, which does not match positional order for `PRECEDING` (e.g. `5 PRECEDING`
/// sorts after `1 PRECEDING` even though it lies earlier).
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
4025
4026impl Display for WindowFrameBound {
4027 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4028 match self {
4029 WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4030 WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4031 WindowFrameBound::CurrentRow => write!(f, "current row"),
4032 WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4033 WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4034 }
4035 }
4036}
4037
/// Maximum iterations for a LetRec (the RECURSION LIMIT option of `WITH MUTUALLY RECURSIVE`),
/// together with what to do when that limit is reached.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// What happens when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result;
    /// if false, an error is raised instead.
    pub return_at_limit: bool,
}
4058
4059impl LetRecLimit {
4060 /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4061 pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4062 limits
4063 .iter()
4064 .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4065 .min()
4066 }
4067
4068 /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4069 /// WMR without ERROR AT or RETURN AT.
4070 pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4071}
4072
4073impl Display for LetRecLimit {
4074 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4075 write!(f, "[recursion_limit={}", self.max_iters)?;
4076 if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4077 write!(f, ", return_at_limit")?;
4078 }
4079 write!(f, "]")
4080 }
4081}
4082
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
///
/// Compared by `MreDiff` when diffing two `Get` nodes, so two otherwise identical Gets with
/// different access strategies are reported as differing.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4109
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// Checks the EXPLAIN text rendering of a `RowSetFinishing`.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            // Left at its default; the expected text below contains no offset entry.
            offset: Default::default(),
            // Columns 3, 4, 5 are consecutive and are expected to render as `#3..=#5`.
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // Expected rendering: a single `Finish ...` line terminated by a newline.
        let exp = {
            // `FormatBuffer` supplies the `write!`/`writeln!` targets for `String` here; the
            // results are not checked, presumably because these writes are infallible.
            use mz_ore::fmt::FormatBuffer;
            let mut s = String::new();
            write!(&mut s, "Finish");
            write!(&mut s, " order_by=[#4 desc nulls_last]");
            write!(&mut s, " limit=7");
            write!(&mut s, " output=[#1, #3..=#5]");
            writeln!(&mut s, "");
            s
        };

        assert_eq!(act, exp);
    }

    /// Checks that `RowSetFinishingIncremental` deducts each batch from the remaining
    /// `max_returned_query_size` budget and errors once the budget would be exceeded.
    #[mz_ore::test]
    fn test_row_set_finishing_incremental_max_returned_query_size() {
        // Each batch holds a single copy of one row whose encoded size is `row_size`.
        let row = Row::pack_slice(&[Datum::String("hello")]);
        let row_size = u64::cast_from(row.data().len());
        let diff = NonZeroUsize::new(1).unwrap();
        let batch = RowCollection::new(vec![(row, diff)], &[]);

        // Set max_returned_query_size to hold exactly 2 batches worth of rows.
        let mut finishing = RowSetFinishingIncremental::new(0, None, vec![0], Some(row_size * 2));

        // Make the overall result-size limit irrelevant so only the returned-size budget applies.
        let max_result_size = u64::MAX;

        // First batch: budget drops from 2 * row_size to row_size.
        let r = finishing.finish_incremental_inner(batch.clone(), max_result_size);
        assert!(r.is_ok());
        assert_eq!(finishing.remaining_max_returned_query_size, Some(row_size));

        // Second batch: budget is exactly exhausted, which is still allowed.
        let r = finishing.finish_incremental_inner(batch.clone(), max_result_size);
        assert!(r.is_ok());
        assert_eq!(finishing.remaining_max_returned_query_size, Some(0));

        // Third batch: no budget remains, so the call must fail with a size error.
        let r = finishing.finish_incremental_inner(batch, max_result_size);
        assert!(r.unwrap_err().contains("total result exceeds max size"));
    }
}
4176
/// Iterators over pairs of AST structures that call out the nodes where they differ.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the places where the ASTs differ. Incidentally, they also compare two ASTs
/// without recursion: pending pairs are kept on an explicit work list (a `Vec`) rather than
/// on the call stack.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    ///
    /// Each call to `next` yields the next pair of corresponding sub-expressions that do not
    /// match. The children of a mismatched pair are not explored further. Consequently, the two
    /// input expressions compare equal under this traversal exactly when the first call to
    /// `next` returns `None`.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        /// Used as a LIFO stack: the most recently pushed pair is compared next.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        /// Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        /// Pops pairs off the work list until one does not match, and returns it.
        ///
        /// For two nodes of the same variant, the node-local (non-child) fields are compared
        /// first. If any differ, the pair is returned. Otherwise the child pairs are pushed
        /// onto the work list. Nodes of different variants always mismatch (final `_` arm).
        /// Returns `None` once the work list is exhausted.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    // Leaf: compare the literal rows and the declared type.
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Leaf: the access strategy is compared as well as the id and type.
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // `value` is pushed last, so it is compared before `body`.
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    // The length check guarantees `zip_eq` below cannot panic.
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // The length check guarantees `zip_eq` below cannot panic.
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // No node-local fields: only the input needs comparing.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    // No node-local fields: only the input needs comparing.
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    // Only the number of inputs is node-local; `base` and the inputs are
                    // compared pairwise in position order.
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Different variants (or any variant not listed above) always mismatch.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            None
        }
    }
}
4471}