// mz_expr/relation.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36 DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39 ColumnName, ColumnType, Datum, Diff, GlobalId, IntoRowIterator, RelationType, Row, RowIterator,
40 ScalarType,
41};
42use proptest::prelude::{Arbitrary, BoxedStrategy, any};
43use proptest::strategy::{Strategy, Union};
44use proptest_derive::Arbitrary;
45use serde::{Deserialize, Serialize};
46
47use crate::Id::Local;
48use crate::explain::{HumanizedExpr, HumanizerMode};
49use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
50use crate::row::{RowCollection, SortedRowCollectionIter};
51use crate::visit::{Visit, VisitChildren};
52use crate::{
53 EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, VariadicFunc,
54 func as scalar_func,
55};
56
// Submodules of the relation expression module.
pub mod canonicalize;
pub mod func;
pub mod join_input_mapper;

// Include code generated by the build script into this module
// (presumably the `mz_expr.relation` proto types backing the `RustType`/`ProtoType`
// impls — confirm against build.rs).
include!(concat!(env!("OUT_DIR"), "/mz_expr.relation.rs"));
62
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
///
/// NOTE(review): traversals enforcing a limit of this kind presumably surface
/// [`RecursionLimitError`] when exceeded — confirm at the use sites.
pub const RECURSION_LIMIT: usize = 2048;
74
75/// A trait for types that describe how to build a collection.
76pub trait CollectionPlan {
77 /// Collects the set of global identifiers from dataflows referenced in Get.
78 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
79
80 /// Returns the set of global identifiers from dataflows referenced in Get.
81 ///
82 /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
83 fn depends_on(&self) -> BTreeSet<GlobalId> {
84 let mut out = BTreeSet::new();
85 self.depends_on_into(&mut out);
86 out
87 }
88}
89
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities.
        /// `Err` represents a constant that evaluates to an error.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: RelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: RelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain, if a limit is given.
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
318
319impl PartialEq for MirRelationExpr {
320 fn eq(&self, other: &Self) -> bool {
321 // Capture the result and test it wrt `Ord` implementation in test environments.
322 let result = structured_diff::MreDiff::new(self, other).next().is_none();
323 mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
324 result
325 }
326}
327impl Eq for MirRelationExpr {}
328
impl Arbitrary for MirRelationExpr {
    type Parameters = ();

    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        // Upper bound (exclusive) for the proptest-generated child vectors below,
        // keeping generated trees small.
        const VEC_LEN: usize = 2;

        // A strategy for generating the leaf cases.
        let leaf = Union::new([
            (
                any::<Result<Vec<(Row, Diff)>, EvalError>>(),
                any::<RelationType>(),
            )
                .prop_map(|(rows, typ)| MirRelationExpr::Constant { rows, typ })
                .boxed(),
            (any::<Id>(), any::<RelationType>(), any::<AccessStrategy>())
                .prop_map(|(id, typ, access_strategy)| MirRelationExpr::Get {
                    id,
                    typ,
                    access_strategy,
                })
                .boxed(),
        ])
        .boxed();

        // Wrap the leaves with the recursive variants. The `prop_recursive`
        // arguments are (recursion depth, desired total size, expected branch size).
        leaf.prop_recursive(2, 2, 2, |inner| {
            Union::new([
                // Let
                (any::<LocalId>(), inner.clone(), inner.clone())
                    .prop_map(|(id, value, body)| MirRelationExpr::Let {
                        id,
                        value: Box::new(value),
                        body: Box::new(body),
                    })
                    .boxed(),
                // LetRec
                (
                    any::<Vec<LocalId>>(),
                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
                    any::<Vec<Option<LetRecLimit>>>(),
                    inner.clone(),
                )
                    .prop_map(|(ids, values, limits, body)| MirRelationExpr::LetRec {
                        ids,
                        values,
                        limits,
                        body: Box::new(body),
                    })
                    .boxed(),
                // Project
                (inner.clone(), any::<Vec<usize>>())
                    .prop_map(|(input, outputs)| MirRelationExpr::Project {
                        input: Box::new(input),
                        outputs,
                    })
                    .boxed(),
                // Map
                (inner.clone(), any::<Vec<MirScalarExpr>>())
                    .prop_map(|(input, scalars)| MirRelationExpr::Map {
                        input: Box::new(input),
                        scalars,
                    })
                    .boxed(),
                // FlatMap
                (
                    inner.clone(),
                    any::<TableFunc>(),
                    any::<Vec<MirScalarExpr>>(),
                )
                    .prop_map(|(input, func, exprs)| MirRelationExpr::FlatMap {
                        input: Box::new(input),
                        func,
                        exprs,
                    })
                    .boxed(),
                // Filter
                (inner.clone(), any::<Vec<MirScalarExpr>>())
                    .prop_map(|(input, predicates)| MirRelationExpr::Filter {
                        input: Box::new(input),
                        predicates,
                    })
                    .boxed(),
                // Join
                (
                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
                    any::<Vec<Vec<MirScalarExpr>>>(),
                    any::<JoinImplementation>(),
                )
                    .prop_map(
                        |(inputs, equivalences, implementation)| MirRelationExpr::Join {
                            inputs,
                            equivalences,
                            implementation,
                        },
                    )
                    .boxed(),
                // Reduce
                (
                    inner.clone(),
                    any::<Vec<MirScalarExpr>>(),
                    any::<Vec<AggregateExpr>>(),
                    any::<bool>(),
                    any::<Option<u64>>(),
                )
                    .prop_map(
                        |(input, group_key, aggregates, monotonic, expected_group_size)| {
                            MirRelationExpr::Reduce {
                                input: Box::new(input),
                                group_key,
                                aggregates,
                                monotonic,
                                expected_group_size,
                            }
                        },
                    )
                    .boxed(),
                // TopK
                (
                    inner.clone(),
                    any::<Vec<usize>>(),
                    any::<Vec<ColumnOrder>>(),
                    any::<Option<MirScalarExpr>>(),
                    any::<usize>(),
                    any::<bool>(),
                    any::<Option<u64>>(),
                )
                    .prop_map(
                        |(
                            input,
                            group_key,
                            order_key,
                            limit,
                            offset,
                            monotonic,
                            expected_group_size,
                        )| MirRelationExpr::TopK {
                            input: Box::new(input),
                            group_key,
                            order_key,
                            limit,
                            offset,
                            monotonic,
                            expected_group_size,
                        },
                    )
                    .boxed(),
                // Negate
                inner
                    .clone()
                    .prop_map(|input| MirRelationExpr::Negate {
                        input: Box::new(input),
                    })
                    .boxed(),
                // Threshold
                inner
                    .clone()
                    .prop_map(|input| MirRelationExpr::Threshold {
                        input: Box::new(input),
                    })
                    .boxed(),
                // Union
                (
                    inner.clone(),
                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
                )
                    .prop_map(|(base, inputs)| MirRelationExpr::Union {
                        base: Box::new(base),
                        inputs,
                    })
                    .boxed(),
                // ArrangeBy
                (inner.clone(), any::<Vec<Vec<MirScalarExpr>>>())
                    .prop_map(|(input, keys)| MirRelationExpr::ArrangeBy {
                        input: Box::new(input),
                        keys,
                    })
                    .boxed(),
            ])
        })
        .boxed()
    }

    type Strategy = BoxedStrategy<Self>;
}
512
513impl MirRelationExpr {
    /// Reports the schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> RelationType {
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: optionally restrict which children get traversed.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    _ => None,
                }
            },
            // Post-visit: pop the children's types off `type_stack` and push this
            // node's type, so the stack always holds the types of pending siblings.
            &mut |e: &MirRelationExpr| {
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(RelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack
                            .extend(std::iter::repeat(RelationType::empty()).take(values.len()));
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // After the traversal only the root's type remains on the stack.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
572
573 /// Reports the schema of the relation given the schema of the input relations.
574 ///
575 /// `input_types` is required to contain the schemas for the input relations of
576 /// the current relation in the same order as they are visited by `try_visit_children`
577 /// method, even though not all may be used for computing the schema of the
578 /// current relation. For example, `Let` expects two input types, one for the
579 /// value relation and one for the body, in that order, but only the one for the
580 /// body is used to determine the type of the `Let` relation.
581 ///
582 /// It is meant to be used during post-order traversals to compute relation
583 /// schemas incrementally.
584 pub fn typ_with_input_types(&self, input_types: &[RelationType]) -> RelationType {
585 let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
586 let unique_keys = self.keys_with_input_keys(
587 input_types.iter().map(|i| i.arity()),
588 input_types.iter().map(|i| &i.keys),
589 );
590 RelationType::new(column_types).with_keys(unique_keys)
591 }
592
593 /// Reports the column types of the relation given the column types of the
594 /// input relations.
595 ///
596 /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
597 /// variant is returned.
598 pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ColumnType>
599 where
600 I: Iterator<Item = &'a Vec<ColumnType>>,
601 {
602 match self.try_col_with_input_cols(input_types) {
603 Ok(col_types) => col_types,
604 Err(err) => panic!("{err}"),
605 }
606 }
607
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    ///
    /// Returns an `Err` with a rendered message when input column types cannot be
    /// unioned (see the `Union` arm below).
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: mark a column non-nullable unless some
                // row actually contains a null in that position.
                let mut col_types = typ.column_types.clone();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null was observed, so the declared type must have
                        // already permitted nulls.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    // Each scalar may reference columns appended by earlier scalars,
                    // so its type is computed against the accumulated `result`.
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                let mut result = input_types.next().unwrap().clone();
                result.extend(func.output_type().column_types);
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output columns are the group key expressions followed by the aggregates.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                // These operators pass their input's columns through unchanged.
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Union the column types of all inputs pairwise; incompatible
                // columns produce the `Err` returned by this method.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
735
736 /// Reports the unique keys of the relation given the arities and the unique
737 /// keys of the input relations.
738 ///
739 /// `input_arities` and `input_keys` are required to contain the
740 /// corresponding info for the input relations of
741 /// the current relation in the same order as they are visited by `try_visit_children`
742 /// method, even though not all may be used for computing the schema of the
743 /// current relation. For example, `Let` expects two input types, one for the
744 /// value relation and one for the body, in that order, but only the one for the
745 /// body is used to determine the type of the `Let` relation.
746 ///
747 /// It is meant to be used during post-order traversals to compute unique keys
748 /// incrementally.
749 pub fn keys_with_input_keys<'a, I, J>(
750 &self,
751 mut input_arities: I,
752 mut input_keys: J,
753 ) -> Vec<Vec<usize>>
754 where
755 I: Iterator<Item = usize>,
756 J: Iterator<Item = &'a Vec<Vec<usize>>>,
757 {
758 use MirRelationExpr::*;
759
760 let mut keys = match self {
761 Constant {
762 rows: Ok(rows),
763 typ,
764 } => {
765 let n_cols = typ.arity();
766 // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
767 let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
768 for (row, diff) in rows {
769 for (i, datum) in row.iter().enumerate() {
770 if datum != Datum::Dummy {
771 if let Some(unique_vals) = &mut unique_values_per_col[i] {
772 let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
773 if is_dupe {
774 unique_values_per_col[i] = None;
775 }
776 }
777 }
778 }
779 }
780 if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
781 vec![vec![]]
782 } else {
783 // XXX - Multi-column keys are not detected.
784 typ.keys
785 .iter()
786 .cloned()
787 .chain(
788 unique_values_per_col
789 .into_iter()
790 .enumerate()
791 .filter(|(_idx, unique_vals)| unique_vals.is_some())
792 .map(|(idx, _)| vec![idx]),
793 )
794 .collect()
795 }
796 }
797 Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
798 Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
799 Let { .. } => {
800 // skip over the unique keys for value
801 input_keys.nth(1).unwrap().clone()
802 }
803 LetRec { values, .. } => {
804 // skip over the unique keys for value
805 input_keys.nth(values.len()).unwrap().clone()
806 }
807 Project { outputs, .. } => {
808 let input = input_keys.next().unwrap();
809 input
810 .iter()
811 .filter_map(|key_set| {
812 if key_set.iter().all(|k| outputs.contains(k)) {
813 Some(
814 key_set
815 .iter()
816 .map(|c| outputs.iter().position(|o| o == c).unwrap())
817 .collect(),
818 )
819 } else {
820 None
821 }
822 })
823 .collect()
824 }
825 Map { scalars, .. } => {
826 let mut remappings = Vec::new();
827 let arity = input_arities.next().unwrap();
828 for (column, scalar) in scalars.iter().enumerate() {
829 // assess whether the scalar preserves uniqueness,
830 // and could participate in a key!
831
832 fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
833 match expr {
834 MirScalarExpr::CallUnary { func, expr } => {
835 if func.preserves_uniqueness() {
836 uniqueness(expr)
837 } else {
838 None
839 }
840 }
841 MirScalarExpr::Column(c, _name) => Some(*c),
842 _ => None,
843 }
844 }
845
846 if let Some(c) = uniqueness(scalar) {
847 remappings.push((c, column + arity));
848 }
849 }
850
851 let mut result = input_keys.next().unwrap().clone();
852 let mut new_keys = Vec::new();
853 // Any column in `remappings` could be replaced in a key
854 // by the corresponding c. This could lead to combinatorial
855 // explosion using our current representation, so we wont
856 // do that. Instead, we'll handle the case of one remapping.
857 if remappings.len() == 1 {
858 let (old, new) = remappings.pop().unwrap();
859 for key in &result {
860 if key.contains(&old) {
861 let mut new_key: Vec<usize> =
862 key.iter().cloned().filter(|k| k != &old).collect();
863 new_key.push(new);
864 new_key.sort_unstable();
865 new_keys.push(new_key);
866 }
867 }
868 result.append(&mut new_keys);
869 }
870 result
871 }
872 FlatMap { .. } => {
873 // FlatMap can add duplicate rows, so input keys are no longer
874 // valid
875 vec![]
876 }
877 Negate { .. } => {
878 // Although negate may have distinct records for each key,
879 // the multiplicity is -1 rather than 1. This breaks many
880 // of the optimization uses of "keys".
881 vec![]
882 }
883 Filter { predicates, .. } => {
884 // A filter inherits the keys of its input unless the filters
885 // have reduced the input to a single row, in which case the
886 // keys of the input are `()`.
887 let mut input = input_keys.next().unwrap().clone();
888
889 if !input.is_empty() {
890 // Track columns equated to literals, which we can prune.
891 let mut cols_equal_to_literal = BTreeSet::new();
892
893 // Perform union find on `col1 = col2` to establish
894 // connected components of equated columns. Absent any
895 // equalities, this will be `0 .. #c` (where #c is the
896 // greatest column referenced by a predicate), but each
897 // equality will orient the root of the greater to the root
898 // of the lesser.
899 let mut union_find = Vec::new();
900
901 for expr in predicates.iter() {
902 if let MirScalarExpr::CallBinary {
903 func: crate::BinaryFunc::Eq,
904 expr1,
905 expr2,
906 } = expr
907 {
908 if let MirScalarExpr::Column(c, _name) = &**expr1 {
909 if expr2.is_literal_ok() {
910 cols_equal_to_literal.insert(c);
911 }
912 }
913 if let MirScalarExpr::Column(c, _name) = &**expr2 {
914 if expr1.is_literal_ok() {
915 cols_equal_to_literal.insert(c);
916 }
917 }
918 // Perform union-find to equate columns.
919 if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
920 if c1 != c2 {
921 // Ensure union_find has entries up to
922 // max(c1, c2) by filling up missing
923 // positions with identity mappings.
924 while union_find.len() <= std::cmp::max(c1, c2) {
925 union_find.push(union_find.len());
926 }
927 let mut r1 = c1; // Find the representative column of [c1].
928 while r1 != union_find[r1] {
929 assert!(union_find[r1] < r1);
930 r1 = union_find[r1];
931 }
932 let mut r2 = c2; // Find the representative column of [c2].
933 while r2 != union_find[r2] {
934 assert!(union_find[r2] < r2);
935 r2 = union_find[r2];
936 }
937 // Union [c1] and [c2] by pointing the
938 // larger to the smaller representative (we
939 // update the remaining equivalence class
940 // members only once after this for-loop).
941 union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
942 }
943 }
944 }
945 }
946
947 // Complete union-find by pointing each element at its representative column.
948 for i in 0..union_find.len() {
949 // Iteration not required, as each prior already references the right column.
950 union_find[i] = union_find[union_find[i]];
951 }
952
953 // Remove columns bound to literals, and remap columns equated to earlier columns.
954 // We will re-expand remapped columns in a moment, but this avoids exponential work.
955 for key_set in &mut input {
956 key_set.retain(|k| !cols_equal_to_literal.contains(&k));
957 for col in key_set.iter_mut() {
958 if let Some(equiv) = union_find.get(*col) {
959 *col = *equiv;
960 }
961 }
962 key_set.sort();
963 key_set.dedup();
964 }
965 input.sort();
966 input.dedup();
967
968 // Expand out each key to each of its equivalent forms.
969 // Each instance of `col` can be replaced by any equivalent column.
970 // This has the potential to result in exponentially sized number of unique keys,
971 // and in the future we should probably maintain unique keys modulo equivalence.
972
973 // First, compute an inverse map from each representative
974 // column `sub` to all other equivalent columns `col`.
975 let mut subs = Vec::new();
976 for (col, sub) in union_find.iter().enumerate() {
977 if *sub != col {
978 assert!(*sub < col);
979 while subs.len() <= *sub {
980 subs.push(Vec::new());
981 }
982 subs[*sub].push(col);
983 }
984 }
985 // For each column, substitute for it in each occurrence.
986 let mut to_add = Vec::new();
987 for (col, subs) in subs.iter().enumerate() {
988 if !subs.is_empty() {
989 for key_set in input.iter() {
990 if key_set.contains(&col) {
991 let mut to_extend = key_set.clone();
992 to_extend.retain(|c| c != &col);
993 for sub in subs {
994 to_extend.push(*sub);
995 to_add.push(to_extend.clone());
996 to_extend.pop();
997 }
998 }
999 }
1000 }
1001 // No deduplication, as we cannot introduce duplicates.
1002 input.append(&mut to_add);
1003 }
1004 for key_set in input.iter_mut() {
1005 key_set.sort();
1006 key_set.dedup();
1007 }
1008 }
1009 input
1010 }
1011 Join { equivalences, .. } => {
1012 // It is important the `new_from_input_arities` constructor is
1013 // used. Otherwise, Materialize may potentially end up in an
1014 // infinite loop.
1015 let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
1016
1017 input_mapper.global_keys(input_keys, equivalences)
1018 }
1019 Reduce { group_key, .. } => {
1020 // The group key should form a key, but we might already have
1021 // keys that are subsets of the group key, and should retain
1022 // those instead, if so.
1023 let mut result = Vec::new();
1024 for key_set in input_keys.next().unwrap() {
1025 if key_set
1026 .iter()
1027 .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
1028 {
1029 result.push(
1030 key_set
1031 .iter()
1032 .map(|i| {
1033 group_key
1034 .iter()
1035 .position(|k| k == &MirScalarExpr::column(*i))
1036 .unwrap()
1037 })
1038 .collect::<Vec<_>>(),
1039 );
1040 }
1041 }
1042 if result.is_empty() {
1043 result.push((0..group_key.len()).collect());
1044 }
1045 result
1046 }
1047 TopK {
1048 group_key, limit, ..
1049 } => {
1050 // If `limit` is `Some(1)` then the group key will become
1051 // a unique key, as there will be only one record with that key.
1052 let mut result = input_keys.next().unwrap().clone();
1053 if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
1054 result.push(group_key.clone())
1055 }
1056 result
1057 }
1058 Union { base, inputs } => {
1059 // Generally, unions do not have any unique keys, because
1060 // each input might duplicate some. However, there is at
1061 // least one idiomatic structure that does preserve keys,
1062 // which results from SQL aggregations that must populate
1063 // absent records with default values. In that pattern,
1064 // the union of one GET with its negation, which has first
1065 // been subjected to a projection and map, we can remove
1066 // their influence on the key structure.
1067 //
1068 // If there are A, B, each with a unique `key` such that
1069 // we are looking at
1070 //
1071 // A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
1072 //
1073 // Then we can report `key` as a unique key.
1074 //
1075 // TODO: make unique key structure an optimization analysis
1076 // rather than part of the type information.
1077 // TODO: perhaps ensure that (above) A.proj(key) is a
1078 // subset of B, as otherwise there are negative records
1079 // and who knows what is true (not expected, but again
1080 // who knows what the query plan might look like).
1081
1082 let arity = input_arities.next().unwrap();
1083 let (base_projection, base_with_project_stripped) =
1084 if let MirRelationExpr::Project { input, outputs } = &**base {
1085 (outputs.clone(), &**input)
1086 } else {
1087 // A input without a project is equivalent to an input
1088 // with the project being all columns in the input in order.
1089 ((0..arity).collect::<Vec<_>>(), &**base)
1090 };
1091 let mut result = Vec::new();
1092 if let MirRelationExpr::Get {
1093 id: first_id,
1094 typ: _,
1095 ..
1096 } = base_with_project_stripped
1097 {
1098 if inputs.len() == 1 {
1099 if let MirRelationExpr::Map { input, .. } = &inputs[0] {
1100 if let MirRelationExpr::Union { base, inputs } = &**input {
1101 if inputs.len() == 1 {
1102 if let Some((input, outputs)) = base.is_negated_project() {
1103 if let MirRelationExpr::Get {
1104 id: second_id,
1105 typ: _,
1106 ..
1107 } = input
1108 {
1109 if first_id == second_id {
1110 result.extend(
1111 input_keys
1112 .next()
1113 .unwrap()
1114 .into_iter()
1115 .filter(|key| {
1116 key.iter().all(|c| {
1117 outputs.get(*c) == Some(c)
1118 && base_projection.get(*c)
1119 == Some(c)
1120 })
1121 })
1122 .cloned(),
1123 );
1124 }
1125 }
1126 }
1127 }
1128 }
1129 }
1130 }
1131 }
1132 // Important: do not inherit keys of either input, as not unique.
1133 result
1134 }
1135 };
1136 keys.sort();
1137 keys.dedup();
1138 keys
1139 }
1140
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        // Stack of arities of already-visited subtrees, in post-order.
        // Invariant: when the post-visit closure runs for a node, the top of
        // the stack holds one entry per input that `arity_with_input_arities`
        // expects (placeholders included for pruned subtrees).
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            &mut |e: &MirRelationExpr| {
                // For nodes whose children were (partially) pruned in the
                // pre-visit above, re-insert placeholder arities (0) so the
                // stack still holds one entry per input, in visit order.
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Only the body was visited; a placeholder stands in
                        // for the skipped value, with the body arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // Only the body was visited; placeholders stand in for
                        // each skipped value, with the body arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // Children were not visited at all; a single
                        // placeholder stands in for the unused input arity.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Consume this node's input arities and push its own.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the traversal, only the root's arity remains.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1198
1199 /// Reports the arity of the relation given the schema of the input relations.
1200 ///
1201 /// `input_arities` is required to contain the arities for the input relations of
1202 /// the current relation in the same order as they are visited by `try_visit_children`
1203 /// method, even though not all may be used for computing the schema of the
1204 /// current relation. For example, `Let` expects two input types, one for the
1205 /// value relation and one for the body, in that order, but only the one for the
1206 /// body is used to determine the type of the `Let` relation.
1207 ///
1208 /// It is meant to be used during post-order traversals to compute arities
1209 /// incrementally.
1210 pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1211 where
1212 I: Iterator<Item = usize>,
1213 {
1214 use MirRelationExpr::*;
1215
1216 match self {
1217 Constant { rows: _, typ } => typ.arity(),
1218 Get { typ, .. } => typ.arity(),
1219 Let { .. } => {
1220 input_arities.next();
1221 input_arities.next().unwrap()
1222 }
1223 LetRec { values, .. } => {
1224 for _ in 0..values.len() {
1225 input_arities.next();
1226 }
1227 input_arities.next().unwrap()
1228 }
1229 Project { outputs, .. } => outputs.len(),
1230 Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1231 FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1232 Join { .. } => input_arities.sum(),
1233 Reduce {
1234 input: _,
1235 group_key,
1236 aggregates,
1237 ..
1238 } => group_key.len() + aggregates.len(),
1239 Filter { .. }
1240 | TopK { .. }
1241 | Negate { .. }
1242 | Threshold { .. }
1243 | Union { .. }
1244 | ArrangeBy { .. } => input_arities.next().unwrap(),
1245 }
1246 }
1247
1248 /// The number of child relations this relation has.
1249 pub fn num_inputs(&self) -> usize {
1250 let mut count = 0;
1251
1252 self.visit_children(|_| count += 1);
1253
1254 count
1255 }
1256
1257 /// Constructs a constant collection from specific rows and schema, where
1258 /// each row will have a multiplicity of one.
1259 pub fn constant(rows: Vec<Vec<Datum>>, typ: RelationType) -> Self {
1260 let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1261 MirRelationExpr::constant_diff(rows, typ)
1262 }
1263
1264 /// Constructs a constant collection from specific rows and schema, where
1265 /// each row can have an arbitrary multiplicity.
1266 pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: RelationType) -> Self {
1267 for (row, _diff) in &rows {
1268 for (datum, column_typ) in row.iter().zip(typ.column_types.iter()) {
1269 assert!(
1270 datum.is_instance_of(column_typ),
1271 "Expected datum of type {:?}, got value {:?}",
1272 column_typ,
1273 datum
1274 );
1275 }
1276 }
1277 let rows = Ok(rows
1278 .into_iter()
1279 .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1280 .collect());
1281 MirRelationExpr::Constant { rows, typ }
1282 }
1283
1284 /// If self is a constant, return the value and the type, otherwise `None`.
1285 /// Looks behind `ArrangeBy`s.
1286 pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &RelationType)> {
1287 match self {
1288 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1289 MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1290 _ => None,
1291 }
1292 }
1293
1294 /// If self is a constant, mutably return the value and the type, otherwise `None`.
1295 /// Looks behind `ArrangeBy`s.
1296 pub fn as_const_mut(
1297 &mut self,
1298 ) -> Option<(&mut Result<Vec<(Row, Diff)>, EvalError>, &mut RelationType)> {
1299 match self {
1300 MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1301 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1302 _ => None,
1303 }
1304 }
1305
1306 /// If self is a constant error, return the error, otherwise `None`.
1307 /// Looks behind `ArrangeBy`s.
1308 pub fn as_const_err(&self) -> Option<&EvalError> {
1309 match self {
1310 MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1311 MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1312 _ => None,
1313 }
1314 }
1315
1316 /// Checks if `self` is the single element collection with no columns.
1317 pub fn is_constant_singleton(&self) -> bool {
1318 if let Some((Ok(rows), typ)) = self.as_const() {
1319 rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1320 } else {
1321 false
1322 }
1323 }
1324
1325 /// Constructs the expression for getting a local collection.
1326 pub fn local_get(id: LocalId, typ: RelationType) -> Self {
1327 MirRelationExpr::Get {
1328 id: Id::Local(id),
1329 typ,
1330 access_strategy: AccessStrategy::UnknownOrLocal,
1331 }
1332 }
1333
1334 /// Constructs the expression for getting a global collection
1335 pub fn global_get(id: GlobalId, typ: RelationType) -> Self {
1336 MirRelationExpr::Get {
1337 id: Id::Global(id),
1338 typ,
1339 access_strategy: AccessStrategy::UnknownOrLocal,
1340 }
1341 }
1342
    /// Retains only the columns specified by `outputs`.
    ///
    /// If `self` is already a `Project`, the two projections are fused into a
    /// single one rather than nesting them.
    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
        if let MirRelationExpr::Project {
            outputs: columns, ..
        } = &mut self
        {
            // Update `outputs` to reference base columns of `input`.
            for column in outputs.iter_mut() {
                *column = columns[*column];
            }
            // Install the composed projection in the existing operator.
            *columns = outputs;
            self
        } else {
            MirRelationExpr::Project {
                input: Box::new(self),
                outputs,
            }
        }
    }
1362
1363 /// Append to each row the results of applying elements of `scalar`.
1364 pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1365 if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1366 s.extend(scalars);
1367 self
1368 } else if !scalars.is_empty() {
1369 MirRelationExpr::Map {
1370 input: Box::new(self),
1371 scalars,
1372 }
1373 } else {
1374 self
1375 }
1376 }
1377
1378 /// Append to each row a single `scalar`.
1379 pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1380 self.map(vec![scalar])
1381 }
1382
1383 /// Like `map`, but yields zero-or-more output rows per input row
1384 pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1385 MirRelationExpr::FlatMap {
1386 input: Box::new(self),
1387 func,
1388 exprs,
1389 }
1390 }
1391
1392 /// Retain only the rows satisfying each of several predicates.
1393 pub fn filter<I>(mut self, predicates: I) -> Self
1394 where
1395 I: IntoIterator<Item = MirScalarExpr>,
1396 {
1397 // Extract existing predicates
1398 let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1399 self = *input;
1400 predicates
1401 } else {
1402 Vec::new()
1403 };
1404 // Normalize collection of predicates.
1405 new_predicates.extend(predicates);
1406 new_predicates.retain(|p| !p.is_literal_true());
1407 new_predicates.sort();
1408 new_predicates.dedup();
1409 // Introduce a `Filter` only if we have predicates.
1410 if !new_predicates.is_empty() {
1411 self = MirRelationExpr::Filter {
1412 input: Box::new(self),
1413 predicates: new_predicates,
1414 };
1415 }
1416
1417 self
1418 }
1419
1420 /// Form the Cartesian outer-product of rows in both inputs.
1421 pub fn product(mut self, right: Self) -> Self {
1422 if right.is_constant_singleton() {
1423 self
1424 } else if self.is_constant_singleton() {
1425 right
1426 } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1427 inputs.push(right);
1428 self
1429 } else {
1430 MirRelationExpr::join(vec![self, right], vec![])
1431 }
1432 }
1433
    /// Performs a relational equijoin among the input collections.
    ///
    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
    /// of pairs `(input_index, column_index)` where every value described by that set must be equal.
    ///
    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
    /// input collection and for each row examining its `column`th column.
    ///
    /// # Example
    ///
    /// ```rust
    /// use mz_repr::{Datum, ColumnType, RelationType, ScalarType};
    /// use mz_expr::MirRelationExpr;
    ///
    /// // A common schema for each input.
    /// let schema = RelationType::new(vec![
    ///     ScalarType::Int32.nullable(false),
    ///     ScalarType::Int32.nullable(false),
    /// ]);
    ///
    /// // the specific data are not important here.
    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
    ///
    /// // Three collections that could have been different.
    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
    ///
    /// // Join the three relations looking for triangles, like so.
    /// //
    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
    /// let joined = MirRelationExpr::join(
    ///     vec![input0, input1, input2],
    ///     vec![
    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
    ///     ],
    /// );
    ///
    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
    /// // A projection resolves this and produces the correct output.
    /// let result = joined.project(vec![0, 1, 3]);
    /// ```
    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);

        // Translate each per-input `(input, column)` reference into a column
        // reference over the concatenated (global) columns of the join, so
        // each equivalence class becomes a list of scalar expressions.
        let equivalences = variables
            .into_iter()
            .map(|vs| {
                vs.into_iter()
                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        Self::join_scalars(inputs, equivalences)
    }
1493
1494 /// Constructs a join operator from inputs and required-equal scalar expressions.
1495 pub fn join_scalars(
1496 mut inputs: Vec<MirRelationExpr>,
1497 equivalences: Vec<Vec<MirScalarExpr>>,
1498 ) -> Self {
1499 // Remove all constant inputs that are the identity for join.
1500 // They neither introduce nor modify any column references.
1501 inputs.retain(|i| !i.is_constant_singleton());
1502 MirRelationExpr::Join {
1503 inputs,
1504 equivalences,
1505 implementation: JoinImplementation::Unimplemented,
1506 }
1507 }
1508
1509 /// Perform a key-wise reduction / aggregation.
1510 ///
1511 /// The `group_key` argument indicates columns in the input collection that should
1512 /// be grouped, and `aggregates` lists aggregation functions each of which produces
1513 /// one output column in addition to the keys.
1514 pub fn reduce(
1515 self,
1516 group_key: Vec<usize>,
1517 aggregates: Vec<AggregateExpr>,
1518 expected_group_size: Option<u64>,
1519 ) -> Self {
1520 MirRelationExpr::Reduce {
1521 input: Box::new(self),
1522 group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1523 aggregates,
1524 monotonic: false,
1525 expected_group_size,
1526 }
1527 }
1528
1529 /// Perform a key-wise reduction order by and limit.
1530 ///
1531 /// The `group_key` argument indicates columns in the input collection that should
1532 /// be grouped, the `order_key` argument indicates columns that should be further
1533 /// used to order records within groups, and the `limit` argument constrains the
1534 /// total number of records that should be produced in each group.
1535 pub fn top_k(
1536 self,
1537 group_key: Vec<usize>,
1538 order_key: Vec<ColumnOrder>,
1539 limit: Option<MirScalarExpr>,
1540 offset: usize,
1541 expected_group_size: Option<u64>,
1542 ) -> Self {
1543 MirRelationExpr::TopK {
1544 input: Box::new(self),
1545 group_key,
1546 order_key,
1547 limit,
1548 offset,
1549 expected_group_size,
1550 monotonic: false,
1551 }
1552 }
1553
1554 /// Negates the occurrences of each row.
1555 pub fn negate(self) -> Self {
1556 if let MirRelationExpr::Negate { input } = self {
1557 *input
1558 } else {
1559 MirRelationExpr::Negate {
1560 input: Box::new(self),
1561 }
1562 }
1563 }
1564
1565 /// Removes all but the first occurrence of each row.
1566 pub fn distinct(self) -> Self {
1567 let arity = self.arity();
1568 self.distinct_by((0..arity).collect())
1569 }
1570
1571 /// Removes all but the first occurrence of each key. Columns not included
1572 /// in the `group_key` are discarded.
1573 pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1574 self.reduce(group_key, vec![], None)
1575 }
1576
1577 /// Discards rows with a negative frequency.
1578 pub fn threshold(self) -> Self {
1579 if let MirRelationExpr::Threshold { .. } = &self {
1580 self
1581 } else {
1582 MirRelationExpr::Threshold {
1583 input: Box::new(self),
1584 }
1585 }
1586 }
1587
1588 /// Unions together any number inputs.
1589 ///
1590 /// If `inputs` is empty, then an empty relation of type `typ` is
1591 /// constructed.
1592 pub fn union_many(mut inputs: Vec<Self>, typ: RelationType) -> Self {
1593 // Deconstruct `inputs` as `Union`s and reconstitute.
1594 let mut flat_inputs = Vec::with_capacity(inputs.len());
1595 for input in inputs {
1596 if let MirRelationExpr::Union { base, inputs } = input {
1597 flat_inputs.push(*base);
1598 flat_inputs.extend(inputs);
1599 } else {
1600 flat_inputs.push(input);
1601 }
1602 }
1603 inputs = flat_inputs;
1604 if inputs.len() == 0 {
1605 MirRelationExpr::Constant {
1606 rows: Ok(vec![]),
1607 typ,
1608 }
1609 } else if inputs.len() == 1 {
1610 inputs.into_element()
1611 } else {
1612 MirRelationExpr::Union {
1613 base: Box::new(inputs.remove(0)),
1614 inputs,
1615 }
1616 }
1617 }
1618
1619 /// Produces one collection where each row is present with the sum of its frequencies in each input.
1620 pub fn union(self, other: Self) -> Self {
1621 // Deconstruct `self` and `other` as `Union`s and reconstitute.
1622 let mut flat_inputs = Vec::with_capacity(2);
1623 if let MirRelationExpr::Union { base, inputs } = self {
1624 flat_inputs.push(*base);
1625 flat_inputs.extend(inputs);
1626 } else {
1627 flat_inputs.push(self);
1628 }
1629 if let MirRelationExpr::Union { base, inputs } = other {
1630 flat_inputs.push(*base);
1631 flat_inputs.extend(inputs);
1632 } else {
1633 flat_inputs.push(other);
1634 }
1635
1636 MirRelationExpr::Union {
1637 base: Box::new(flat_inputs.remove(0)),
1638 inputs: flat_inputs,
1639 }
1640 }
1641
1642 /// Arranges the collection by the specified columns
1643 pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1644 MirRelationExpr::ArrangeBy {
1645 input: Box::new(self),
1646 keys: keys.to_owned(),
1647 }
1648 }
1649
1650 /// Indicates if this is a constant empty collection.
1651 ///
1652 /// A false value does not mean the collection is known to be non-empty,
1653 /// only that we cannot currently determine that it is statically empty.
1654 pub fn is_empty(&self) -> bool {
1655 if let Some((Ok(rows), ..)) = self.as_const() {
1656 rows.is_empty()
1657 } else {
1658 false
1659 }
1660 }
1661
1662 /// If the expression is a negated project, return the input and the projection.
1663 pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1664 if let MirRelationExpr::Negate { input } = self {
1665 if let MirRelationExpr::Project { input, outputs } = &**input {
1666 return Some((&**input, outputs));
1667 }
1668 }
1669 if let MirRelationExpr::Project { input, outputs } = self {
1670 if let MirRelationExpr::Negate { input } = &**input {
1671 return Some((&**input, outputs));
1672 }
1673 }
1674 None
1675 }
1676
1677 /// Pretty-print this [MirRelationExpr] to a string.
1678 pub fn pretty(&self) -> String {
1679 let config = ExplainConfig::default();
1680 self.explain(&config, None)
1681 }
1682
1683 /// Pretty-print this [MirRelationExpr] to a string using a custom
1684 /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1685 pub fn explain(&self, config: &ExplainConfig, humanizer: Option<&dyn ExprHumanizer>) -> String {
1686 text_string_at(self, || PlanRenderingContext {
1687 indent: Indent::default(),
1688 humanizer: humanizer.unwrap_or(&DummyHumanizer),
1689 annotations: BTreeMap::default(),
1690 config,
1691 })
1692 }
1693
    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
    /// would find. Keys and nullability are ignored in the given `RelationType`, and instead we set
    /// the best possible key and nullability, since we are making an empty collection.
    ///
    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
    /// the correct type.
    pub fn take_safely(&mut self, typ: Option<RelationType>) -> MirRelationExpr {
        if let Some(typ) = &typ {
            // Soft-check that the caller-provided scalar types agree with what
            // `typ()` computes; `zip_eq` also catches arity mismatches.
            soft_assert_no_log!(
                self.typ()
                    .column_types
                    .iter()
                    .zip_eq(typ.column_types.iter())
                    .all(|(t1, t2)| t1.scalar_type.base_eq(&t2.scalar_type))
            );
        }
        let mut typ = typ.unwrap_or_else(|| self.typ());
        // The empty collection trivially satisfies the empty key, and no
        // column can ever be null, so install the strongest keys/nullability.
        typ.keys = vec![vec![]];
        for ct in typ.column_types.iter_mut() {
            ct.nullable = false;
        }
        // Swap in the empty constant and hand the previous expression back.
        std::mem::replace(
            self,
            MirRelationExpr::Constant {
                rows: Ok(vec![]),
                typ,
            },
        )
    }
1724
1725 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1726 /// types. Nullability is ignored in the given `ColumnType`s, and instead we set the best
1727 /// possible nullability, since we are making an empty collection.
1728 pub fn take_safely_with_col_types(&mut self, typ: Vec<ColumnType>) -> MirRelationExpr {
1729 self.take_safely(Some(RelationType::new(typ)))
1730 }
1731
1732 /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1733 ///
1734 /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1735 pub fn take_dangerous(&mut self) -> MirRelationExpr {
1736 let empty = MirRelationExpr::Constant {
1737 rows: Ok(vec![]),
1738 typ: RelationType::new(Vec::new()),
1739 };
1740 std::mem::replace(self, empty)
1741 }
1742
1743 /// Replaces `self` with some logic applied to `self`.
1744 pub fn replace_using<F>(&mut self, logic: F)
1745 where
1746 F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1747 {
1748 let empty = MirRelationExpr::Constant {
1749 rows: Ok(vec![]),
1750 typ: RelationType::new(Vec::new()),
1751 };
1752 let expr = std::mem::replace(self, empty);
1753 *self = logic(expr);
1754 }
1755
1756 /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1757 pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1758 where
1759 Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1760 {
1761 if let MirRelationExpr::Get { .. } = self {
1762 // already done
1763 body(id_gen, self)
1764 } else {
1765 let id = LocalId::new(id_gen.allocate_id());
1766 let get = MirRelationExpr::Get {
1767 id: Id::Local(id),
1768 typ: self.typ(),
1769 access_strategy: AccessStrategy::UnknownOrLocal,
1770 };
1771 let body = (body)(id_gen, get)?;
1772 Ok(MirRelationExpr::Let {
1773 id,
1774 value: Box::new(self),
1775 body: Box::new(body),
1776 })
1777 }
1778 }
1779
    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
    ///
    /// The arity of `keys_and_values` must equal the arity of `self` plus the
    /// length of `default`; this is asserted.
    pub fn anti_lookup<E>(
        self,
        id_gen: &mut IdGen,
        keys_and_values: MirRelationExpr,
        default: Vec<(Datum, ScalarType)>,
    ) -> Result<MirRelationExpr, E> {
        // Split `default` into datums and their column types, marking each
        // type nullable iff its default datum is null.
        let (data, column_types): (Vec<_>, Vec<_>) = default
            .into_iter()
            .map(|(datum, scalar_type)| (datum, scalar_type.nullable(datum.is_null())))
            .unzip();
        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
        self.let_in(id_gen, |_id_gen, get_keys| {
            let get_keys_arity = get_keys.arity();
            Ok(MirRelationExpr::join(
                vec![
                    // all the missing keys (with count 1)
                    keys_and_values
                        .distinct_by((0..get_keys_arity).collect())
                        .negate()
                        .union(get_keys.clone().distinct()),
                    // join with keys to get the correct counts
                    get_keys.clone(),
                ],
                // Equate the key columns of the two join inputs pairwise.
                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // get rid of the extra copies of columns from keys
            .project((0..get_keys_arity).collect())
            // This join is logically equivalent to
            // `.map(<default_expr>)`, but using a join allows for
            // potential predicate pushdown and elision in the
            // optimizer.
            .product(MirRelationExpr::constant(
                vec![data],
                RelationType::new(column_types),
            )))
        })
    }
1819
1820 /// Return:
1821 /// * every row in keys_and_values
1822 /// * every row in `self` that does not have a matching row in the first columns of
1823 /// `keys_and_values`, using `default` to fill in the remaining columns
1824 /// (This is LEFT OUTER JOIN if:
1825 /// 1) `default` is a row of null
1826 /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1827 pub fn lookup<E>(
1828 self,
1829 id_gen: &mut IdGen,
1830 keys_and_values: MirRelationExpr,
1831 default: Vec<(Datum<'static>, ScalarType)>,
1832 ) -> Result<MirRelationExpr, E> {
1833 keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1834 Ok(get_keys_and_values.clone().union(self.anti_lookup(
1835 id_gen,
1836 get_keys_and_values,
1837 default,
1838 )?))
1839 })
1840 }
1841
1842 /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1843 pub fn contains_temporal(&self) -> bool {
1844 let mut contains = false;
1845 self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1846 contains
1847 }
1848
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// Short-circuits on the first error returned by `f`.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            // Scalar expressions appended as new columns.
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            // Predicates evaluated against each row.
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            // Arguments to the table function.
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                // Expressions required to be equal across join inputs.
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // Join implementations carry key expressions of their own.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            // Each arrangement key is a sequence of scalar expressions.
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Grouping key expressions ...
                for s in group_key {
                    f(s)?;
                }
                // ... and the input expression of each aggregate.
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            // The optional `limit` is a scalar expression.
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These variants own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1948
1949 /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1950 /// rooted at `self`.
1951 ///
1952 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1953 /// nodes.
1954 pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1955 where
1956 F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1957 E: From<RecursionLimitError>,
1958 {
1959 self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1960 }
1961
1962 /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1963 /// rooted at `self`.
1964 ///
1965 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1966 /// nodes.
1967 pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1968 where
1969 F: FnMut(&mut MirScalarExpr),
1970 {
1971 self.try_visit_scalars_mut(&mut |s| {
1972 f(s);
1973 Ok::<_, RecursionLimitError>(())
1974 })
1975 .expect("Unexpected error in `visit_scalars_mut` call");
1976 }
1977
1978 /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1979 ///
1980 /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1981 pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1982 where
1983 F: FnMut(&MirScalarExpr) -> Result<(), E>,
1984 {
1985 use MirRelationExpr::*;
1986 match self {
1987 Map { scalars, .. } => {
1988 for s in scalars {
1989 f(s)?;
1990 }
1991 }
1992 Filter { predicates, .. } => {
1993 for p in predicates {
1994 f(p)?;
1995 }
1996 }
1997 FlatMap { exprs, .. } => {
1998 for expr in exprs {
1999 f(expr)?;
2000 }
2001 }
2002 Join {
2003 inputs: _,
2004 equivalences,
2005 implementation,
2006 } => {
2007 for equivalence in equivalences {
2008 for expr in equivalence {
2009 f(expr)?;
2010 }
2011 }
2012 match implementation {
2013 JoinImplementation::Differential((_, start_key, _), order) => {
2014 if let Some(start_key) = start_key {
2015 for k in start_key {
2016 f(k)?;
2017 }
2018 }
2019 for (_, lookup_key, _) in order {
2020 for k in lookup_key {
2021 f(k)?;
2022 }
2023 }
2024 }
2025 JoinImplementation::DeltaQuery(paths) => {
2026 for path in paths {
2027 for (_, lookup_key, _) in path {
2028 for k in lookup_key {
2029 f(k)?;
2030 }
2031 }
2032 }
2033 }
2034 JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
2035 for k in index_key {
2036 f(k)?;
2037 }
2038 }
2039 JoinImplementation::Unimplemented => {} // No scalar exprs
2040 }
2041 }
2042 ArrangeBy { keys, .. } => {
2043 for key in keys {
2044 for s in key {
2045 f(s)?;
2046 }
2047 }
2048 }
2049 Reduce {
2050 group_key,
2051 aggregates,
2052 ..
2053 } => {
2054 for s in group_key {
2055 f(s)?;
2056 }
2057 for agg in aggregates {
2058 f(&agg.expr)?;
2059 }
2060 }
2061 TopK { limit, .. } => {
2062 if let Some(s) = limit {
2063 f(s)?;
2064 }
2065 }
2066 Constant { .. }
2067 | Get { .. }
2068 | Let { .. }
2069 | LetRec { .. }
2070 | Project { .. }
2071 | Negate { .. }
2072 | Threshold { .. }
2073 | Union { .. } => (),
2074 }
2075 Ok(())
2076 }
2077
2078 /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2079 /// rooted at `self`.
2080 ///
2081 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2082 /// nodes.
2083 pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
2084 where
2085 F: FnMut(&MirScalarExpr) -> Result<(), E>,
2086 E: From<RecursionLimitError>,
2087 {
2088 self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
2089 }
2090
2091 /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2092 /// rooted at `self`.
2093 ///
2094 /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2095 /// nodes.
2096 pub fn visit_scalars<F>(&self, f: &mut F)
2097 where
2098 F: FnMut(&MirScalarExpr),
2099 {
2100 self.try_visit_scalars(&mut |s| {
2101 f(s);
2102 Ok::<_, RecursionLimitError>(())
2103 })
2104 .expect("Unexpected error in `visit_scalars` call");
2105 }
2106
2107 /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
2108 /// stack overflow in `drop_in_place`.
2109 ///
2110 /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
2111 /// dropped or otherwise overwritten.
2112 pub fn destroy_carefully(&mut self) {
2113 let mut todo = vec![self.take_dangerous()];
2114 while let Some(mut expr) = todo.pop() {
2115 for child in expr.children_mut() {
2116 todo.push(child.take_dangerous());
2117 }
2118 }
2119 }
2120
2121 /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
2122 /// debug printing purposes.
2123 pub fn debug_size_and_depth(&self) -> (usize, usize) {
2124 let mut size = 0;
2125 let mut max_depth = 0;
2126 let mut todo = vec![(self, 1)];
2127 while let Some((expr, depth)) = todo.pop() {
2128 size += 1;
2129 max_depth = max(max_depth, depth);
2130 todo.extend(expr.children().map(|c| (c, depth + 1)));
2131 }
2132 (size, max_depth)
2133 }
2134
2135 /// The MirRelationExpr is considered potentially expensive if and only if
2136 /// at least one of the following conditions is true:
2137 ///
2138 /// - It contains at least one FlatMap or a Reduce operator.
2139 /// - It contains at least one MirScalarExpr with a function call.
2140 ///
2141 /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2142 /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2143 pub fn could_run_expensive_function(&self) -> bool {
2144 let mut result = false;
2145 self.visit_pre(|e: &MirRelationExpr| {
2146 use MirRelationExpr::*;
2147 use MirScalarExpr::*;
2148 if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
2149 result |= match scalar {
2150 Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
2151 // Function calls are considered expensive
2152 CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
2153 };
2154 Ok(())
2155 }) {
2156 // Conservatively set `true` if on RecursionLimitError.
2157 result = true;
2158 }
2159 // FlatMap has a table function; Reduce has an aggregate function.
2160 // Other constructs use MirScalarExpr to run a function
2161 result |= matches!(e, FlatMap { .. } | Reduce { .. });
2162 });
2163 result
2164 }
2165
2166 /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2167 /// than what `Hashable::hashed` would give us.)
2168 pub fn hash_to_u64(&self) -> u64 {
2169 let mut h = DefaultHasher::new();
2170 self.hash(&mut h);
2171 h.finish()
2172 }
2173}
2174
2175// `LetRec` helpers
2176impl MirRelationExpr {
2177 /// True when `expr` contains a `LetRec` AST node.
2178 pub fn is_recursive(self: &MirRelationExpr) -> bool {
2179 let mut worklist = vec![self];
2180 while let Some(expr) = worklist.pop() {
2181 if let MirRelationExpr::LetRec { .. } = expr {
2182 return true;
2183 }
2184 worklist.extend(expr.children());
2185 }
2186 false
2187 }
2188
2189 /// Return the number of sub-expressions in the tree (including self).
2190 pub fn size(&self) -> usize {
2191 let mut size = 0;
2192 self.visit_pre(|_| size += 1);
2193 size
2194 }
2195
2196 /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2197 /// iterations. These are those ids that have a reference before they are defined, when reading
2198 /// all the bindings in order.
2199 ///
2200 /// For example:
2201 /// ```SQL
2202 /// WITH MUTUALLY RECURSIVE
2203 /// x(...) AS f(z),
2204 /// y(...) AS g(x),
2205 /// z(...) AS h(y)
2206 /// ...;
2207 /// ```
2208 /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2209 /// iteration.
2210 ///
2211 /// Note that if a binding references itself, that is also returned.
2212 pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2213 let mut used_across_iterations = BTreeSet::new();
2214 let mut defined = BTreeSet::new();
2215 for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2216 value.visit_pre(|expr| {
2217 if let MirRelationExpr::Get {
2218 id: Local(get_id), ..
2219 } = expr
2220 {
2221 // If we haven't seen a definition for it yet, then this will refer
2222 // to the previous iteration.
2223 // The `ids.contains` part of the condition is needed to exclude
2224 // those ids that are not really in this LetRec, but either an inner
2225 // or outer one.
2226 if !defined.contains(get_id) && ids.contains(get_id) {
2227 used_across_iterations.insert(*get_id);
2228 }
2229 }
2230 });
2231 defined.insert(*binding_id);
2232 }
2233 used_across_iterations
2234 }
2235
2236 /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2237 ///
2238 /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2239 /// identifiers are rewritten to be the constant collection.
2240 /// This makes the computation perform exactly "one" iteration.
2241 ///
2242 /// This was used only temporarily while developing `LetRec`.
2243 pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2244 let mut deadlist = BTreeSet::new();
2245 let mut worklist = vec![self];
2246 while let Some(expr) = worklist.pop() {
2247 if let MirRelationExpr::LetRec {
2248 ids,
2249 values,
2250 limits: _,
2251 body,
2252 } = expr
2253 {
2254 let ids_values = values
2255 .drain(..)
2256 .zip(ids)
2257 .map(|(value, id)| (*id, value))
2258 .collect::<Vec<_>>();
2259 *expr = body.take_dangerous();
2260 for (id, mut value) in ids_values.into_iter().rev() {
2261 // Remove references to potentially recursive identifiers.
2262 deadlist.insert(id);
2263 value.visit_pre_mut(|e| {
2264 if let MirRelationExpr::Get {
2265 id: crate::Id::Local(id),
2266 typ,
2267 ..
2268 } = e
2269 {
2270 let typ = typ.clone();
2271 if deadlist.contains(id) {
2272 e.take_safely(Some(typ));
2273 }
2274 }
2275 });
2276 *expr = MirRelationExpr::Let {
2277 id,
2278 value: Box::new(value),
2279 body: Box::new(expr.take_dangerous()),
2280 };
2281 }
2282 worklist.push(expr);
2283 } else {
2284 worklist.extend(expr.children_mut().rev());
2285 }
2286 }
2287 }
2288
2289 /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2290 /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2291 /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2292 /// redefinition.
2293 ///
2294 /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2295 pub fn collect_expirations(
2296 id: LocalId,
2297 expr: &MirRelationExpr,
2298 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2299 ) {
2300 expr.visit_pre(|e| {
2301 if let MirRelationExpr::Get {
2302 id: Id::Local(referenced_id),
2303 ..
2304 } = e
2305 {
2306 // The following check needs `renumber_bindings` to have run recently
2307 if referenced_id >= &id {
2308 expire_whens
2309 .entry(*referenced_id)
2310 .or_insert_with(Vec::new)
2311 .push(id);
2312 }
2313 }
2314 });
2315 }
2316
2317 /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2318 /// about such Ids whose information depended on the earlier definition of `id`, according to
2319 /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2320 pub fn do_expirations<I>(
2321 redefined_id: LocalId,
2322 expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2323 id_infos: &mut BTreeMap<LocalId, I>,
2324 ) -> Vec<(LocalId, I)> {
2325 let mut expired_infos = Vec::new();
2326 if let Some(expirations) = expire_whens.remove(&redefined_id) {
2327 for expired_id in expirations.into_iter() {
2328 if let Some(offer) = id_infos.remove(&expired_id) {
2329 expired_infos.push((expired_id, offer));
2330 }
2331 }
2332 }
2333 expired_infos
2334 }
2335}
2336/// Augment non-nullability of columns, by observing either
2337/// 1. Predicates that explicitly test for null values, and
2338/// 2. Columns that if null would make a predicate be null.
2339pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2340 let mut nonnull_required_columns = BTreeSet::new();
2341 for predicate in predicates {
2342 // Add any columns that being null would force the predicate to be null.
2343 // Should that happen, the row would be discarded.
2344 predicate.non_null_requirements(&mut nonnull_required_columns);
2345
2346 /*
2347 Test for explicit checks that a column is non-null.
2348
2349 This analysis is ad hoc, and will miss things:
2350
2351 materialize=> create table a(x int, y int);
2352 CREATE TABLE
2353 materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2354 Optimized Plan
2355 --------------------------------------------------------------------------------------------------------
2356 Explained Query: +
2357 Project (#0) // { types: "(integer?)" } +
2358 Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2359 Get materialize.public.a // { types: "(integer?, integer?)" } +
2360 +
2361 Source materialize.public.a +
2362 filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1)))) +
2363
2364 (1 row)
2365 */
2366
2367 if let MirScalarExpr::CallUnary {
2368 func: UnaryFunc::Not(scalar_func::Not),
2369 expr,
2370 } = predicate
2371 {
2372 if let MirScalarExpr::CallUnary {
2373 func: UnaryFunc::IsNull(scalar_func::IsNull),
2374 expr,
2375 } = &**expr
2376 {
2377 if let MirScalarExpr::Column(c, _name) = &**expr {
2378 nonnull_required_columns.insert(*c);
2379 }
2380 }
2381 }
2382 }
2383
2384 nonnull_required_columns
2385}
2386
2387impl CollectionPlan for MirRelationExpr {
2388 // !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2389 // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2390 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2391 if let MirRelationExpr::Get {
2392 id: Id::Global(id), ..
2393 } = self
2394 {
2395 out.insert(*id);
2396 }
2397 self.visit_children(|expr| expr.depends_on_into(out))
2398 }
2399}
2400
2401impl MirRelationExpr {
2402 /// Iterates through references to child expressions.
2403 pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2404 let mut first = None;
2405 let mut second = None;
2406 let mut rest = None;
2407 let mut last = None;
2408
2409 use MirRelationExpr::*;
2410 match self {
2411 Constant { .. } | Get { .. } => (),
2412 Let { value, body, .. } => {
2413 first = Some(&**value);
2414 second = Some(&**body);
2415 }
2416 LetRec { values, body, .. } => {
2417 rest = Some(values);
2418 last = Some(&**body);
2419 }
2420 Project { input, .. }
2421 | Map { input, .. }
2422 | FlatMap { input, .. }
2423 | Filter { input, .. }
2424 | Reduce { input, .. }
2425 | TopK { input, .. }
2426 | Negate { input }
2427 | Threshold { input }
2428 | ArrangeBy { input, .. } => {
2429 first = Some(&**input);
2430 }
2431 Join { inputs, .. } => {
2432 rest = Some(inputs);
2433 }
2434 Union { base, inputs } => {
2435 first = Some(&**base);
2436 rest = Some(inputs);
2437 }
2438 }
2439
2440 first
2441 .into_iter()
2442 .chain(second)
2443 .chain(rest.into_iter().flatten())
2444 .chain(last)
2445 }
2446
2447 /// Iterates through mutable references to child expressions.
2448 pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2449 let mut first = None;
2450 let mut second = None;
2451 let mut rest = None;
2452 let mut last = None;
2453
2454 use MirRelationExpr::*;
2455 match self {
2456 Constant { .. } | Get { .. } => (),
2457 Let { value, body, .. } => {
2458 first = Some(&mut **value);
2459 second = Some(&mut **body);
2460 }
2461 LetRec { values, body, .. } => {
2462 rest = Some(values);
2463 last = Some(&mut **body);
2464 }
2465 Project { input, .. }
2466 | Map { input, .. }
2467 | FlatMap { input, .. }
2468 | Filter { input, .. }
2469 | Reduce { input, .. }
2470 | TopK { input, .. }
2471 | Negate { input }
2472 | Threshold { input }
2473 | ArrangeBy { input, .. } => {
2474 first = Some(&mut **input);
2475 }
2476 Join { inputs, .. } => {
2477 rest = Some(inputs);
2478 }
2479 Union { base, inputs } => {
2480 first = Some(&mut **base);
2481 rest = Some(inputs);
2482 }
2483 }
2484
2485 first
2486 .into_iter()
2487 .chain(second)
2488 .chain(rest.into_iter().flatten())
2489 .chain(last)
2490 }
2491
2492 /// Iterative pre-order visitor.
2493 pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2494 let mut worklist = vec![self];
2495 while let Some(expr) = worklist.pop() {
2496 f(expr);
2497 worklist.extend(expr.children().rev());
2498 }
2499 }
2500
2501 /// Iterative pre-order visitor.
2502 pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2503 let mut worklist = vec![self];
2504 while let Some(expr) = worklist.pop() {
2505 f(expr);
2506 worklist.extend(expr.children_mut().rev());
2507 }
2508 }
2509
2510 /// Return a vector of references to the subtrees of this expression
2511 /// in post-visit order (the last element is `&self`).
2512 pub fn post_order_vec(&self) -> Vec<&Self> {
2513 let mut stack = vec![self];
2514 let mut result = vec![];
2515 while let Some(expr) = stack.pop() {
2516 result.push(expr);
2517 stack.extend(expr.children());
2518 }
2519 result.reverse();
2520 result
2521 }
2522}
2523
2524impl VisitChildren<Self> for MirRelationExpr {
2525 fn visit_children<F>(&self, mut f: F)
2526 where
2527 F: FnMut(&Self),
2528 {
2529 for child in self.children() {
2530 f(child)
2531 }
2532 }
2533
2534 fn visit_mut_children<F>(&mut self, mut f: F)
2535 where
2536 F: FnMut(&mut Self),
2537 {
2538 for child in self.children_mut() {
2539 f(child)
2540 }
2541 }
2542
2543 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2544 where
2545 F: FnMut(&Self) -> Result<(), E>,
2546 E: From<RecursionLimitError>,
2547 {
2548 for child in self.children() {
2549 f(child)?
2550 }
2551 Ok(())
2552 }
2553
2554 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2555 where
2556 F: FnMut(&mut Self) -> Result<(), E>,
2557 E: From<RecursionLimitError>,
2558 {
2559 for child in self.children_mut() {
2560 f(child)?
2561 }
2562 Ok(())
2563 }
2564}
2565
/// Specification for an ordering by a column.
#[derive(
    Arbitrary,
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect,
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    /// `#[serde(default)]` lets older serialized forms without this field
    /// deserialize as `false` (ascending).
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    /// `#[serde(default)]` lets older serialized forms without this field
    /// deserialize as `false` (nulls first).
    #[serde(default)]
    pub nulls_last: bool,
}
2591
// `ColumnOrder` derives `Copy`, so its columnar storage region can be the
// trivial `CopyRegion` (plain byte copies; no owned allocations to spill).
impl Columnation for ColumnOrder {
    type InnerRegion = CopyRegion<Self>;
}
2595
2596impl RustType<ProtoColumnOrder> for ColumnOrder {
2597 fn into_proto(&self) -> ProtoColumnOrder {
2598 ProtoColumnOrder {
2599 column: self.column.into_proto(),
2600 desc: self.desc,
2601 nulls_last: self.nulls_last,
2602 }
2603 }
2604
2605 fn from_proto(proto: ProtoColumnOrder) -> Result<Self, TryFromProtoError> {
2606 Ok(ColumnOrder {
2607 column: proto.column.into_rust()?,
2608 desc: proto.desc,
2609 nulls_last: proto.nulls_last,
2610 })
2611 }
2612}
2613
2614impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2615where
2616 M: HumanizerMode,
2617{
2618 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2619 // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2620 write!(
2621 f,
2622 "{} {} {}",
2623 self.child(&self.expr.column),
2624 if self.expr.desc { "desc" } else { "asc" },
2625 if self.expr.nulls_last {
2626 "nulls_last"
2627 } else {
2628 "nulls_first"
2629 },
2630 )
2631 }
2632}
2633
/// Describes an aggregation expression.
#[derive(
    Arbitrary, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    /// `#[serde(default)]` lets older serialized forms without this field
    /// deserialize as `false`.
    #[serde(default)]
    pub distinct: bool,
}
2647
2648impl RustType<ProtoAggregateExpr> for AggregateExpr {
2649 fn into_proto(&self) -> ProtoAggregateExpr {
2650 ProtoAggregateExpr {
2651 func: Some(self.func.into_proto()),
2652 expr: Some(self.expr.into_proto()),
2653 distinct: self.distinct,
2654 }
2655 }
2656
2657 fn from_proto(proto: ProtoAggregateExpr) -> Result<Self, TryFromProtoError> {
2658 Ok(Self {
2659 func: proto.func.into_rust_if_some("ProtoAggregateExpr::func")?,
2660 expr: proto.expr.into_rust_if_some("ProtoAggregateExpr::expr")?,
2661 distinct: proto.distinct,
2662 })
2663 }
2664}
2665
2666impl AggregateExpr {
2667 /// Computes the type of this `AggregateExpr`.
2668 pub fn typ(&self, column_types: &[ColumnType]) -> ColumnType {
2669 self.func.output_type(self.expr.typ(column_types))
2670 }
2671
2672 /// Returns whether the expression has a constant result.
2673 pub fn is_constant(&self) -> bool {
2674 match self.func {
2675 AggregateFunc::MaxInt16
2676 | AggregateFunc::MaxInt32
2677 | AggregateFunc::MaxInt64
2678 | AggregateFunc::MaxUInt16
2679 | AggregateFunc::MaxUInt32
2680 | AggregateFunc::MaxUInt64
2681 | AggregateFunc::MaxMzTimestamp
2682 | AggregateFunc::MaxFloat32
2683 | AggregateFunc::MaxFloat64
2684 | AggregateFunc::MaxBool
2685 | AggregateFunc::MaxString
2686 | AggregateFunc::MaxDate
2687 | AggregateFunc::MaxTimestamp
2688 | AggregateFunc::MaxTimestampTz
2689 | AggregateFunc::MinInt16
2690 | AggregateFunc::MinInt32
2691 | AggregateFunc::MinInt64
2692 | AggregateFunc::MinUInt16
2693 | AggregateFunc::MinUInt32
2694 | AggregateFunc::MinUInt64
2695 | AggregateFunc::MinMzTimestamp
2696 | AggregateFunc::MinFloat32
2697 | AggregateFunc::MinFloat64
2698 | AggregateFunc::MinBool
2699 | AggregateFunc::MinString
2700 | AggregateFunc::MinDate
2701 | AggregateFunc::MinTimestamp
2702 | AggregateFunc::MinTimestampTz
2703 | AggregateFunc::Any
2704 | AggregateFunc::All
2705 | AggregateFunc::Dummy => self.expr.is_literal(),
2706 AggregateFunc::Count => self.expr.is_literal_null(),
2707 _ => self.expr.is_literal_err(),
2708 }
2709 }
2710
2711 /// Returns an expression that computes `self` on a group that has exactly one row.
2712 /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2713 /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2714 pub fn on_unique(&self, input_type: &[ColumnType]) -> MirScalarExpr {
2715 match &self.func {
2716 // Count is one if non-null, and zero if null.
2717 AggregateFunc::Count => self
2718 .expr
2719 .clone()
2720 .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2721 .if_then_else(
2722 MirScalarExpr::literal_ok(Datum::Int64(0), ScalarType::Int64),
2723 MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
2724 ),
2725
2726 // SumInt16 takes Int16s as input, but outputs Int64s.
2727 AggregateFunc::SumInt16 => self
2728 .expr
2729 .clone()
2730 .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2731
2732 // SumInt32 takes Int32s as input, but outputs Int64s.
2733 AggregateFunc::SumInt32 => self
2734 .expr
2735 .clone()
2736 .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2737
2738 // SumInt64 takes Int64s as input, but outputs numerics.
2739 AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2740 scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2741 )),
2742
2743 // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2744 AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2745 UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2746 ),
2747
2748 // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2749 AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2750 UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2751 ),
2752
2753 // SumUInt64 takes UInt64s as input, but outputs numerics.
2754 AggregateFunc::SumUInt64 => {
2755 self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2756 scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2757 ))
2758 }
2759
2760 // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2761 AggregateFunc::JsonbAgg { .. } => MirScalarExpr::CallVariadic {
2762 func: VariadicFunc::JsonbBuildArray,
2763 exprs: vec![
2764 self.expr
2765 .clone()
2766 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2767 ],
2768 },
2769
2770 // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2771 AggregateFunc::JsonbObjectAgg { .. } => {
2772 let record = self
2773 .expr
2774 .clone()
2775 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2776 MirScalarExpr::CallVariadic {
2777 func: VariadicFunc::JsonbBuildObject,
2778 exprs: (0..2)
2779 .map(|i| {
2780 record
2781 .clone()
2782 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2783 })
2784 .collect(),
2785 }
2786 }
2787
2788 AggregateFunc::MapAgg { value_type, .. } => {
2789 let record = self
2790 .expr
2791 .clone()
2792 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2793 MirScalarExpr::CallVariadic {
2794 func: VariadicFunc::MapBuild {
2795 value_type: value_type.clone(),
2796 },
2797 exprs: (0..2)
2798 .map(|i| {
2799 record
2800 .clone()
2801 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2802 })
2803 .collect(),
2804 }
2805 }
2806
2807 // StringAgg takes nested records of strings and outputs a string
2808 AggregateFunc::StringAgg { .. } => self
2809 .expr
2810 .clone()
2811 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2812 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2813
2814 // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2815 AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2816 .expr
2817 .clone()
2818 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2819
2820 // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2821 AggregateFunc::RowNumber { .. } => {
2822 self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2823 }
2824 AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2825 AggregateFunc::DenseRank { .. } => {
2826 self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2827 }
2828
2829 // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2830 AggregateFunc::LagLead { lag_lead, .. } => {
2831 let tuple = self
2832 .expr
2833 .clone()
2834 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2835
2836 // Get the overall return type
2837 let return_type_with_orig_row = self
2838 .typ(input_type)
2839 .scalar_type
2840 .unwrap_list_element_type()
2841 .clone();
2842 let lag_lead_return_type =
2843 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2844
2845 // Extract the original row
2846 let original_row = tuple
2847 .clone()
2848 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2849
2850 // Extract the encoded args
2851 let encoded_args =
2852 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2853
2854 let (result_expr, column_name) =
2855 Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2856
2857 MirScalarExpr::CallVariadic {
2858 func: VariadicFunc::ListCreate {
2859 elem_type: return_type_with_orig_row,
2860 },
2861 exprs: vec![MirScalarExpr::CallVariadic {
2862 func: VariadicFunc::RecordCreate {
2863 field_names: vec![column_name, ColumnName::from("?record?")],
2864 },
2865 exprs: vec![result_expr, original_row],
2866 }],
2867 }
2868 }
2869
2870 // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2871 AggregateFunc::FirstValue { window_frame, .. } => {
2872 let tuple = self
2873 .expr
2874 .clone()
2875 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2876
2877 // Get the overall return type
2878 let return_type_with_orig_row = self
2879 .typ(input_type)
2880 .scalar_type
2881 .unwrap_list_element_type()
2882 .clone();
2883 let first_value_return_type =
2884 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2885
2886 // Extract the original row
2887 let original_row = tuple
2888 .clone()
2889 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2890
2891 // Extract the input value
2892 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2893
2894 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2895 window_frame,
2896 arg,
2897 first_value_return_type,
2898 );
2899
2900 MirScalarExpr::CallVariadic {
2901 func: VariadicFunc::ListCreate {
2902 elem_type: return_type_with_orig_row,
2903 },
2904 exprs: vec![MirScalarExpr::CallVariadic {
2905 func: VariadicFunc::RecordCreate {
2906 field_names: vec![column_name, ColumnName::from("?record?")],
2907 },
2908 exprs: vec![result_expr, original_row],
2909 }],
2910 }
2911 }
2912
2913 // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2914 AggregateFunc::LastValue { window_frame, .. } => {
2915 let tuple = self
2916 .expr
2917 .clone()
2918 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2919
2920 // Get the overall return type
2921 let return_type_with_orig_row = self
2922 .typ(input_type)
2923 .scalar_type
2924 .unwrap_list_element_type()
2925 .clone();
2926 let last_value_return_type =
2927 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2928
2929 // Extract the original row
2930 let original_row = tuple
2931 .clone()
2932 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2933
2934 // Extract the input value
2935 let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2936
2937 let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2938 window_frame,
2939 arg,
2940 last_value_return_type,
2941 );
2942
2943 MirScalarExpr::CallVariadic {
2944 func: VariadicFunc::ListCreate {
2945 elem_type: return_type_with_orig_row,
2946 },
2947 exprs: vec![MirScalarExpr::CallVariadic {
2948 func: VariadicFunc::RecordCreate {
2949 field_names: vec![column_name, ColumnName::from("?record?")],
2950 },
2951 exprs: vec![result_expr, original_row],
2952 }],
2953 }
2954 }
2955
2956 // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2957 // See an example MIR in `window_func_applied_to`.
2958 AggregateFunc::WindowAggregate {
2959 wrapped_aggregate,
2960 window_frame,
2961 order_by: _,
2962 } => {
2963 // TODO: deduplicate code between the various window function cases.
2964
2965 let tuple = self
2966 .expr
2967 .clone()
2968 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2969
2970 // Get the overall return type
2971 let return_type = self
2972 .typ(input_type)
2973 .scalar_type
2974 .unwrap_list_element_type()
2975 .clone();
2976 let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2977
2978 // Extract the original row
2979 let original_row = tuple
2980 .clone()
2981 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2982
2983 // Extract the input value
2984 let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2985
2986 let (result, column_name) = Self::on_unique_window_agg(
2987 window_frame,
2988 arg_expr,
2989 input_type,
2990 window_agg_return_type,
2991 wrapped_aggregate,
2992 );
2993
2994 MirScalarExpr::CallVariadic {
2995 func: VariadicFunc::ListCreate {
2996 elem_type: return_type,
2997 },
2998 exprs: vec![MirScalarExpr::CallVariadic {
2999 func: VariadicFunc::RecordCreate {
3000 field_names: vec![column_name, ColumnName::from("?record?")],
3001 },
3002 exprs: vec![result, original_row],
3003 }],
3004 }
3005 }
3006
3007 // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
3008 AggregateFunc::FusedWindowAggregate {
3009 wrapped_aggregates,
3010 order_by: _,
3011 window_frame,
3012 } => {
3013 // Throw away OrderByExprs
3014 let tuple = self
3015 .expr
3016 .clone()
3017 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3018
3019 // Extract the original row
3020 let original_row = tuple
3021 .clone()
3022 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3023
3024 // Extract the args of the fused call
3025 let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3026
3027 let return_type_with_orig_row = self
3028 .typ(input_type)
3029 .scalar_type
3030 .unwrap_list_element_type()
3031 .clone();
3032
3033 let all_func_return_types =
3034 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3035 let mut func_result_exprs = Vec::new();
3036 let mut col_names = Vec::new();
3037 for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
3038 let arg = all_args
3039 .clone()
3040 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3041 let return_type =
3042 all_func_return_types.unwrap_record_element_type()[idx].clone();
3043 let (result, column_name) = Self::on_unique_window_agg(
3044 window_frame,
3045 arg,
3046 input_type,
3047 return_type,
3048 wrapped_aggr,
3049 );
3050 func_result_exprs.push(result);
3051 col_names.push(column_name);
3052 }
3053
3054 MirScalarExpr::CallVariadic {
3055 func: VariadicFunc::ListCreate {
3056 elem_type: return_type_with_orig_row,
3057 },
3058 exprs: vec![MirScalarExpr::CallVariadic {
3059 func: VariadicFunc::RecordCreate {
3060 field_names: vec![
3061 ColumnName::from("?fused_window_aggr?"),
3062 ColumnName::from("?record?"),
3063 ],
3064 },
3065 exprs: vec![
3066 MirScalarExpr::CallVariadic {
3067 func: VariadicFunc::RecordCreate {
3068 field_names: col_names,
3069 },
3070 exprs: func_result_exprs,
3071 },
3072 original_row,
3073 ],
3074 }],
3075 }
3076 }
3077
3078 // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
3079 AggregateFunc::FusedValueWindowFunc {
3080 funcs,
3081 order_by: outer_order_by,
3082 } => {
3083 // Throw away OrderByExprs
3084 let tuple = self
3085 .expr
3086 .clone()
3087 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3088
3089 // Extract the original row
3090 let original_row = tuple
3091 .clone()
3092 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3093
3094 // Extract the encoded args of the fused call
3095 let all_encoded_args =
3096 tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3097
3098 let return_type_with_orig_row = self
3099 .typ(input_type)
3100 .scalar_type
3101 .unwrap_list_element_type()
3102 .clone();
3103
3104 let all_func_return_types =
3105 return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3106 let mut func_result_exprs = Vec::new();
3107 let mut col_names = Vec::new();
3108 for (idx, func) in funcs.iter().enumerate() {
3109 let args_for_func = all_encoded_args
3110 .clone()
3111 .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3112 let return_type_for_func =
3113 all_func_return_types.unwrap_record_element_type()[idx].clone();
3114 let (result, column_name) = match func {
3115 AggregateFunc::LagLead {
3116 lag_lead,
3117 order_by,
3118 ignore_nulls: _,
3119 } => {
3120 assert_eq!(order_by, outer_order_by);
3121 Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
3122 }
3123 AggregateFunc::FirstValue {
3124 window_frame,
3125 order_by,
3126 } => {
3127 assert_eq!(order_by, outer_order_by);
3128 Self::on_unique_first_value_last_value(
3129 window_frame,
3130 args_for_func,
3131 return_type_for_func,
3132 )
3133 }
3134 AggregateFunc::LastValue {
3135 window_frame,
3136 order_by,
3137 } => {
3138 assert_eq!(order_by, outer_order_by);
3139 Self::on_unique_first_value_last_value(
3140 window_frame,
3141 args_for_func,
3142 return_type_for_func,
3143 )
3144 }
3145 _ => panic!("unknown function in FusedValueWindowFunc"),
3146 };
3147 func_result_exprs.push(result);
3148 col_names.push(column_name);
3149 }
3150
3151 MirScalarExpr::CallVariadic {
3152 func: VariadicFunc::ListCreate {
3153 elem_type: return_type_with_orig_row,
3154 },
3155 exprs: vec![MirScalarExpr::CallVariadic {
3156 func: VariadicFunc::RecordCreate {
3157 field_names: vec![
3158 ColumnName::from("?fused_value_window_func?"),
3159 ColumnName::from("?record?"),
3160 ],
3161 },
3162 exprs: vec![
3163 MirScalarExpr::CallVariadic {
3164 func: VariadicFunc::RecordCreate {
3165 field_names: col_names,
3166 },
3167 exprs: func_result_exprs,
3168 },
3169 original_row,
3170 ],
3171 }],
3172 }
3173 }
3174
3175 // All other variants should return the argument to the aggregation.
3176 AggregateFunc::MaxNumeric
3177 | AggregateFunc::MaxInt16
3178 | AggregateFunc::MaxInt32
3179 | AggregateFunc::MaxInt64
3180 | AggregateFunc::MaxUInt16
3181 | AggregateFunc::MaxUInt32
3182 | AggregateFunc::MaxUInt64
3183 | AggregateFunc::MaxMzTimestamp
3184 | AggregateFunc::MaxFloat32
3185 | AggregateFunc::MaxFloat64
3186 | AggregateFunc::MaxBool
3187 | AggregateFunc::MaxString
3188 | AggregateFunc::MaxDate
3189 | AggregateFunc::MaxTimestamp
3190 | AggregateFunc::MaxTimestampTz
3191 | AggregateFunc::MaxInterval
3192 | AggregateFunc::MaxTime
3193 | AggregateFunc::MinNumeric
3194 | AggregateFunc::MinInt16
3195 | AggregateFunc::MinInt32
3196 | AggregateFunc::MinInt64
3197 | AggregateFunc::MinUInt16
3198 | AggregateFunc::MinUInt32
3199 | AggregateFunc::MinUInt64
3200 | AggregateFunc::MinMzTimestamp
3201 | AggregateFunc::MinFloat32
3202 | AggregateFunc::MinFloat64
3203 | AggregateFunc::MinBool
3204 | AggregateFunc::MinString
3205 | AggregateFunc::MinDate
3206 | AggregateFunc::MinTimestamp
3207 | AggregateFunc::MinTimestampTz
3208 | AggregateFunc::MinInterval
3209 | AggregateFunc::MinTime
3210 | AggregateFunc::SumFloat32
3211 | AggregateFunc::SumFloat64
3212 | AggregateFunc::SumNumeric
3213 | AggregateFunc::Any
3214 | AggregateFunc::All
3215 | AggregateFunc::Dummy => self.expr.clone(),
3216 }
3217 }
3218
    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
    ///
    /// With a single-row group, each of these ranking functions yields the
    /// literal `1`, paired with the original row. `col_name` is the output
    /// column name for the ranking value in the produced record.
    fn on_unique_ranking_window_funcs(
        &self,
        input_type: &[ColumnType],
        col_name: &str,
    ) -> MirScalarExpr {
        let list = self
            .expr
            .clone()
            // extract the list within the record
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

        // extract the expression within the list
        // (index `1`: presumably 1-based, SQL-style list indexing — the single
        // element of the singleton list)
        let record = MirScalarExpr::CallVariadic {
            func: VariadicFunc::ListIndex,
            exprs: vec![
                list,
                MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
            ],
        };

        // Build `[{<col_name>: 1, ?record?: <record>}]`, whose type is the
        // list element type of this aggregate's result type on `input_type`.
        MirScalarExpr::CallVariadic {
            func: VariadicFunc::ListCreate {
                elem_type: self
                    .typ(input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone(),
            },
            exprs: vec![MirScalarExpr::CallVariadic {
                func: VariadicFunc::RecordCreate {
                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
                },
                exprs: vec![
                    // The ranking value: always 1 for a single-row group.
                    MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
                    record,
                ],
            }],
        }
    }
3259
    /// `on_unique` for `lag` and `lead`
    ///
    /// `encoded_args` is a record of the three lag/lead arguments:
    /// `(expr, offset, default_value)`. Returns the result expression together
    /// with the output column name (`?lag?` or `?lead?`).
    fn on_unique_lag_lead(
        lag_lead: &LagLeadType,
        encoded_args: MirScalarExpr,
        return_type: ScalarType,
    ) -> (MirScalarExpr, ColumnName) {
        // Unpack the three encoded arguments from the record.
        let expr = encoded_args
            .clone()
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
        let offset = encoded_args
            .clone()
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
        let default_value =
            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));

        // In this case, the window always has only one element, so if the offset is not null and
        // not zero, the default value should be returned instead.
        let value = offset
            .clone()
            .call_binary(
                MirScalarExpr::literal_ok(Datum::Int32(0), ScalarType::Int32),
                crate::BinaryFunc::Eq,
            )
            .if_then_else(expr, default_value);
        // A null offset yields a null result; this check wraps (and therefore
        // takes precedence over) the zero-offset check above.
        let result_expr = offset
            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
            .if_then_else(MirScalarExpr::literal_null(return_type), value);

        let column_name = ColumnName::from(match lag_lead {
            LagLeadType::Lag => "?lag?",
            LagLeadType::Lead => "?lead?",
        });

        (result_expr, column_name)
    }
3295
3296 /// `on_unique` for `first_value` and `last_value`
3297 fn on_unique_first_value_last_value(
3298 window_frame: &WindowFrame,
3299 arg: MirScalarExpr,
3300 return_type: ScalarType,
3301 ) -> (MirScalarExpr, ColumnName) {
3302 // If the window frame includes the current (single) row, return its value, null otherwise
3303 let result_expr = if window_frame.includes_current_row() {
3304 arg
3305 } else {
3306 MirScalarExpr::literal_null(return_type)
3307 };
3308 (result_expr, ColumnName::from("?first_value?"))
3309 }
3310
    /// `on_unique` for window aggregations
    ///
    /// `arg_expr` is the argument of the wrapped aggregation, and `return_type`
    /// is the wrapped aggregation's result type (not the return type of the
    /// enclosing window function). Returns the result expression and the output
    /// column name.
    fn on_unique_window_agg(
        window_frame: &WindowFrame,
        arg_expr: MirScalarExpr,
        input_type: &[ColumnType],
        return_type: ScalarType,
        wrapped_aggr: &AggregateFunc,
    ) -> (MirScalarExpr, ColumnName) {
        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
        // that row. Otherwise, return the default value for the aggregate.
        let result_expr = if window_frame.includes_current_row() {
            // Recurse: apply the wrapped aggregate's own `on_unique` logic to
            // the single input element.
            AggregateExpr {
                func: wrapped_aggr.clone(),
                expr: arg_expr,
                distinct: false, // We have just one input element; DISTINCT doesn't matter.
            }
            .on_unique(input_type)
        } else {
            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
        };
        (result_expr, ColumnName::from("?window_agg?"))
    }
3333
3334 /// Returns whether the expression is COUNT(*) or not. Note that
3335 /// when we define the count builtin in sql::func, we convert
3336 /// COUNT(*) to COUNT(true), making it indistinguishable from
3337 /// literal COUNT(true), but we prefer to consider this as the
3338 /// former.
3339 ///
3340 /// (HIR has the same `is_count_asterisk`.)
3341 pub fn is_count_asterisk(&self) -> bool {
3342 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3343 }
3344}
3345
/// Describe a join implementation in dataflow.
#[derive(
    Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        // The starting collection: (input index, optional arrangement key,
        // characteristics).
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        // The remaining collections in join order: (input index, arrangement
        // key, characteristics).
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3395
3396impl Default for JoinImplementation {
3397 fn default() -> Self {
3398 JoinImplementation::Unimplemented
3399 }
3400}
3401
3402impl JoinImplementation {
3403 /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3404 pub fn is_implemented(&self) -> bool {
3405 match self {
3406 Self::Unimplemented => false,
3407 _ => true,
3408 }
3409 }
3410
3411 /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3412 pub fn name(&self) -> Option<&'static str> {
3413 match self {
3414 Self::Differential(..) => Some("differential"),
3415 Self::DeltaQuery(..) => Some("delta"),
3416 Self::IndexedFilter(..) => Some("indexed_filter"),
3417 Self::Unimplemented => None,
3418 }
3419 }
3420}
3421
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag (`enable_join_prioritize_arranged`).
#[derive(
    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3442
3443impl JoinInputCharacteristics {
3444 /// Creates a new instance with the given characteristics.
3445 pub fn new(
3446 unique_key: bool,
3447 key_length: usize,
3448 arranged: bool,
3449 cardinality: Option<usize>,
3450 filters: FilterCharacteristics,
3451 input: usize,
3452 enable_join_prioritize_arranged: bool,
3453 ) -> Self {
3454 if enable_join_prioritize_arranged {
3455 Self::V2(JoinInputCharacteristicsV2::new(
3456 unique_key,
3457 key_length,
3458 arranged,
3459 cardinality,
3460 filters,
3461 input,
3462 ))
3463 } else {
3464 Self::V1(JoinInputCharacteristicsV1::new(
3465 unique_key,
3466 key_length,
3467 arranged,
3468 cardinality,
3469 filters,
3470 input,
3471 ))
3472 }
3473 }
3474
3475 /// Turns the instance into a String to be printed in EXPLAIN.
3476 pub fn explain(&self) -> String {
3477 match self {
3478 Self::V1(jic) => jic.explain(),
3479 Self::V2(jic) => jic.explain(),
3480 }
3481 }
3482
3483 /// Whether the join input described by `self` is arranged.
3484 pub fn arranged(&self) -> bool {
3485 match self {
3486 Self::V1(jic) => jic.arranged,
3487 Self::V2(jic) => jic.arranged,
3488 }
3489 }
3490
3491 /// Returns the `FilterCharacteristics` for the join input described by `self`.
3492 pub fn filters(&mut self) -> &mut FilterCharacteristics {
3493 match self {
3494 Self::V1(jic) => &mut jic.filters,
3495 Self::V2(jic) => &mut jic.filters,
3496 }
3497 }
3498}
3499
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// Note: `Ord`/`PartialOrd` are derived, which compares fields
/// lexicographically in declaration order, so the field order below encodes
/// the relative importance of the characteristics (most important first).
#[derive(
    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better, hence `Reverse`)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3523
3524impl JoinInputCharacteristicsV2 {
3525 /// Creates a new instance with the given characteristics.
3526 pub fn new(
3527 unique_key: bool,
3528 key_length: usize,
3529 arranged: bool,
3530 cardinality: Option<usize>,
3531 filters: FilterCharacteristics,
3532 input: usize,
3533 ) -> Self {
3534 Self {
3535 unique_key,
3536 not_cross: key_length > 0,
3537 arranged,
3538 key_length,
3539 cardinality: cardinality.map(std::cmp::Reverse),
3540 filters,
3541 input: std::cmp::Reverse(input),
3542 }
3543 }
3544
3545 /// Turns the instance into a String to be printed in EXPLAIN.
3546 pub fn explain(&self) -> String {
3547 let mut e = "".to_owned();
3548 if self.unique_key {
3549 e.push_str("U");
3550 }
3551 // Don't need to print `not_cross`, because that is visible in the printed key.
3552 // if !self.not_cross {
3553 // e.push_str("C");
3554 // }
3555 for _ in 0..self.key_length {
3556 e.push_str("K");
3557 }
3558 if self.arranged {
3559 e.push_str("A");
3560 }
3561 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3562 e.push_str(&format!("|{cardinality}|"));
3563 }
3564 e.push_str(&self.filters.explain());
3565 e
3566 }
3567}
3568
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// Note: `Ord`/`PartialOrd` are derived, which compares fields
/// lexicographically in declaration order, so the field order below encodes
/// the relative importance of the characteristics (most important first).
#[derive(
    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better, hence `Reverse`)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3587
3588impl JoinInputCharacteristicsV1 {
3589 /// Creates a new instance with the given characteristics.
3590 pub fn new(
3591 unique_key: bool,
3592 key_length: usize,
3593 arranged: bool,
3594 cardinality: Option<usize>,
3595 filters: FilterCharacteristics,
3596 input: usize,
3597 ) -> Self {
3598 Self {
3599 unique_key,
3600 key_length,
3601 arranged,
3602 cardinality: cardinality.map(std::cmp::Reverse),
3603 filters,
3604 input: std::cmp::Reverse(input),
3605 }
3606 }
3607
3608 /// Turns the instance into a String to be printed in EXPLAIN.
3609 pub fn explain(&self) -> String {
3610 let mut e = "".to_owned();
3611 if self.unique_key {
3612 e.push_str("U");
3613 }
3614 for _ in 0..self.key_length {
3615 e.push_str("K");
3616 }
3617 if self.arranged {
3618 e.push_str("A");
3619 }
3620 if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3621 e.push_str(&format!("|{cardinality}|"));
3622 }
3623 e.push_str(&self.filters.explain());
3624 e
3625 }
3626}
3627
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them. The defaults (`L = NonNeg<i64>`, `O = usize`)
/// represent fully-resolved values.
#[derive(Arbitrary, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset).
    pub limit: Option<L>,
    /// Omit as many rows.
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3650
impl RustType<ProtoRowSetFinishing> for RowSetFinishing {
    /// Converts to the protobuf representation, field by field.
    fn into_proto(&self) -> ProtoRowSetFinishing {
        ProtoRowSetFinishing {
            order_by: self.order_by.into_proto(),
            limit: self.limit.into_proto(),
            offset: self.offset.into_proto(),
            project: self.project.into_proto(),
        }
    }

    /// Converts back from the protobuf representation; fails if any field
    /// fails to convert.
    fn from_proto(x: ProtoRowSetFinishing) -> Result<Self, TryFromProtoError> {
        Ok(RowSetFinishing {
            order_by: x.order_by.into_rust()?,
            limit: x.limit.into_rust()?,
            offset: x.offset.into_rust()?,
            project: x.project.into_rust()?,
        })
    }
}
3670
3671impl<L> RowSetFinishing<L> {
3672 /// Returns a trivial finishing, i.e., that does nothing to the result set.
3673 pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3674 RowSetFinishing {
3675 order_by: Vec::new(),
3676 limit: None,
3677 offset: 0,
3678 project: (0..arity).collect(),
3679 }
3680 }
3681 /// True if the finishing does nothing to any result set.
3682 pub fn is_trivial(&self, arity: usize) -> bool {
3683 self.limit.is_none()
3684 && self.order_by.is_empty()
3685 && self.offset == 0
3686 && self.project.iter().copied().eq(0..arity)
3687 }
3688 /// True if the finishing does not require an ORDER BY.
3689 ///
3690 /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3691 /// explicit ordering we will skip an arbitrary bag of elements and return
3692 /// the first arbitrary elements in the remaining bag. The result semantics
3693 /// are still correct but maybe surprising for some users.
3694 pub fn is_streamable(&self, arity: usize) -> bool {
3695 self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3696 }
3697}
3698
3699impl RowSetFinishing<NonNeg<i64>, usize> {
3700 /// The number of rows needed from before the finishing to evaluate the finishing:
3701 /// offset + limit.
3702 ///
3703 /// If it returns None, then we need all the rows.
3704 pub fn num_rows_needed(&self) -> Option<usize> {
3705 self.limit
3706 .as_ref()
3707 .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3708 }
3709}
3710
impl RowSetFinishing {
    /// Applies finishing actions to a [`RowCollection`], and reports the total
    /// time it took to run.
    ///
    /// Returns a [`SortedRowCollectionIter`] that contains all of the response data, as
    /// well as the size of the response in bytes.
    pub fn finish(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
        duration_histogram: &Histogram,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // Time the whole finishing operation and record it in the histogram,
        // regardless of whether finishing succeeds or fails.
        let now = Instant::now();
        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishing::finish`].
    fn finish_inner(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // How much additional memory is required to make a sorted view.
        // (One `usize` of indirection per row entry.)
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("result exceeds max size of {max_bytes}",));
        }

        // Sort according to the ORDER BY, then apply OFFSET and the projection.
        let sorted_view = rows.sorted_view(&self.order_by);
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.offset)
            .with_projection(self.project.clone());

        // Apply the LIMIT, if any.
        if let Some(limit) = self.limit {
            let limit = u64::from(limit);
            let limit = usize::cast_from(limit);
            iter = iter.with_limit(limit);
        };

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        if let Some(max) = max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(max);
                return Err(format!("result exceeds max size of {max_bytes}"));
            }
        }

        Ok((iter, response_size))
    }
}
3781
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset). Decremented as batches are
    /// processed by `finish_incremental`.
    pub remaining_limit: Option<usize>,
    /// Omit as many rows. Decremented as batches are processed.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    /// Kept unchanged, for use in error messages.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3798
impl RowSetFinishingIncremental {
    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
    /// `true`.
    ///
    /// # Panics
    ///
    /// Panics if the result is not streamable, that is it has an ORDER BY.
    pub fn new(
        offset: usize,
        limit: Option<NonNeg<i64>>,
        project: Vec<usize>,
        max_returned_query_size: Option<u64>,
    ) -> Self {
        // NOTE(review): nothing in this body asserts streamability, despite the
        // documented panic — confirm the check happens at the call sites.
        //
        // Convert the SQL-level limit (a non-negative i64) to a usize row count.
        let limit = limit.map(|l| {
            let l = u64::from(l);
            let l = usize::cast_from(l);
            l
        });

        RowSetFinishingIncremental {
            remaining_limit: limit,
            remaining_offset: offset,
            max_returned_query_size,
            remaining_max_returned_query_size: max_returned_query_size,
            project,
        }
    }

    /// Applies finishing actions to the given [`RowCollection`], and reports
    /// the total time it took to run.
    ///
    /// Returns a [`SortedRowCollectionIter`] that contains all of the response
    /// data.
    pub fn finish_incremental(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
        duration_histogram: &Histogram,
    ) -> Result<SortedRowCollectionIter, String> {
        // Time the batch and record it, regardless of success or failure.
        let now = Instant::now();
        let result = self.finish_incremental_inner(rows, max_result_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishingIncremental::finish_incremental`].
    fn finish_incremental_inner(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
    ) -> Result<SortedRowCollectionIter, String> {
        // How much additional memory is required to make a sorted view.
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("total result exceeds max size of {max_bytes}",));
        }

        let batch_num_rows = rows.count(0, None);

        // No ORDER BY here (`&[]`): incremental finishing is only used for
        // streamable (unordered) finishings.
        let sorted_view = rows.sorted_view(&[]);
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.remaining_offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.remaining_limit {
            iter = iter.with_limit(limit);
        };

        // Account for this batch: the offset is consumed by the batch's rows,
        // the limit by the rows the iterator will actually yield.
        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
            // NOTE(review): this assumes `count` is an inherent method on the
            // iterator that does not consume it (`iter` is used again below) —
            // confirm on `SortedRowCollectionIter`.
            *remaining_limit -= iter.count();
        }

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        // NOTE(review): `remaining_max_returned_query_size` is checked but not
        // decremented in this method — confirm the budget is updated elsewhere.
        if let Some(max) = self.remaining_max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
                return Err(format!("total result exceeds max size of {max_bytes}"));
            }
        }

        Ok(iter)
    }
}
3899
3900/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3901/// ordering, call `tiebreaker`.
3902pub fn compare_columns<F>(
3903 order: &[ColumnOrder],
3904 left: &[Datum],
3905 right: &[Datum],
3906 tiebreaker: F,
3907) -> Ordering
3908where
3909 F: Fn() -> Ordering,
3910{
3911 for order in order {
3912 let cmp = match (&left[order.column], &right[order.column]) {
3913 (Datum::Null, Datum::Null) => Ordering::Equal,
3914 (Datum::Null, _) => {
3915 if order.nulls_last {
3916 Ordering::Greater
3917 } else {
3918 Ordering::Less
3919 }
3920 }
3921 (_, Datum::Null) => {
3922 if order.nulls_last {
3923 Ordering::Less
3924 } else {
3925 Ordering::Greater
3926 }
3927 }
3928 (lval, rval) => {
3929 if order.desc {
3930 rval.cmp(lval)
3931 } else {
3932 lval.cmp(rval)
3933 }
3934 }
3935 };
3936 if cmp != Ordering::Equal {
3937 return cmp;
3938 }
3939 }
3940 tiebreaker()
3941}
3942
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Arbitrary, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3959
3960impl Display for WindowFrame {
3961 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3962 write!(
3963 f,
3964 "{} between {} and {}",
3965 self.units, self.start_bound, self.end_bound
3966 )
3967 }
3968}
3969
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined
    ///
    /// This is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`.
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Whether this frame includes the current row.
    ///
    /// Note that `OffsetPreceding(0)` and `OffsetFollowing(0)` denote the
    /// current row itself. The `unreachable!()` arms cover frames whose end
    /// bound comes before their start bound; presumably such frames are
    /// rejected before reaching this code -- TODO(review): confirm where.
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            // Frame starts at the beginning of the partition; it reaches the
            // current row unless the end bound also stays strictly before it.
            UnboundedPreceding => match self.end_bound {
                UnboundedPreceding => false,
                // `0 PRECEDING` is the current row.
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Frame starts at the current row (`0 PRECEDING`).
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Frame starts strictly before the current row.
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            CurrentRow => true,
            // Frame starts at the current row (`0 FOLLOWING`); only bounds at
            // or after the current row are valid ends here.
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Frame starts strictly after the current row.
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            UnboundedFollowing => false,
        }
    }
}
4028
4029impl RustType<ProtoWindowFrame> for WindowFrame {
4030 fn into_proto(&self) -> ProtoWindowFrame {
4031 ProtoWindowFrame {
4032 units: Some(self.units.into_proto()),
4033 start_bound: Some(self.start_bound.into_proto()),
4034 end_bound: Some(self.end_bound.into_proto()),
4035 }
4036 }
4037
4038 fn from_proto(proto: ProtoWindowFrame) -> Result<Self, TryFromProtoError> {
4039 Ok(WindowFrame {
4040 units: proto.units.into_rust_if_some("ProtoWindowFrame::units")?,
4041 start_bound: proto
4042 .start_bound
4043 .into_rust_if_some("ProtoWindowFrame::start_bound")?,
4044 end_bound: proto
4045 .end_bound
4046 .into_rust_if_some("ProtoWindowFrame::end_bound")?,
4047 })
4048 }
4049}
4050
/// Describe how frame bounds are interpreted
///
/// Corresponds to the ROWS / RANGE / GROUPS keyword of the SQL frame clause;
/// see [`WindowFrame::units`].
#[derive(
    Arbitrary, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
4065
4066impl Display for WindowFrameUnits {
4067 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4068 match self {
4069 WindowFrameUnits::Rows => write!(f, "rows"),
4070 WindowFrameUnits::Range => write!(f, "range"),
4071 WindowFrameUnits::Groups => write!(f, "groups"),
4072 }
4073 }
4074}
4075
4076impl RustType<proto_window_frame::ProtoWindowFrameUnits> for WindowFrameUnits {
4077 fn into_proto(&self) -> proto_window_frame::ProtoWindowFrameUnits {
4078 use proto_window_frame::proto_window_frame_units::Kind::*;
4079 proto_window_frame::ProtoWindowFrameUnits {
4080 kind: Some(match self {
4081 WindowFrameUnits::Rows => Rows(()),
4082 WindowFrameUnits::Range => Range(()),
4083 WindowFrameUnits::Groups => Groups(()),
4084 }),
4085 }
4086 }
4087
4088 fn from_proto(
4089 proto: proto_window_frame::ProtoWindowFrameUnits,
4090 ) -> Result<Self, TryFromProtoError> {
4091 use proto_window_frame::proto_window_frame_units::Kind::*;
4092 Ok(match proto.kind {
4093 Some(Rows(())) => WindowFrameUnits::Rows,
4094 Some(Range(())) => WindowFrameUnits::Range,
4095 Some(Groups(())) => WindowFrameUnits::Groups,
4096 None => {
4097 return Err(TryFromProtoError::missing_field(
4098 "ProtoWindowFrameUnits::kind",
4099 ));
4100 }
4101 })
4102 }
4103}
4104
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there.
///
/// Note that an offset of 0 (`OffsetPreceding(0)` / `OffsetFollowing(0)`)
/// denotes the current row.
#[derive(
    Arbitrary, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, MzReflect, PartialOrd, Ord,
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}
4124
4125impl Display for WindowFrameBound {
4126 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4127 match self {
4128 WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4129 WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4130 WindowFrameBound::CurrentRow => write!(f, "current row"),
4131 WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4132 WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4133 }
4134 }
4135}
4136
4137impl RustType<proto_window_frame::ProtoWindowFrameBound> for WindowFrameBound {
4138 fn into_proto(&self) -> proto_window_frame::ProtoWindowFrameBound {
4139 use proto_window_frame::proto_window_frame_bound::Kind::*;
4140 proto_window_frame::ProtoWindowFrameBound {
4141 kind: Some(match self {
4142 WindowFrameBound::UnboundedPreceding => UnboundedPreceding(()),
4143 WindowFrameBound::OffsetPreceding(offset) => OffsetPreceding(*offset),
4144 WindowFrameBound::CurrentRow => CurrentRow(()),
4145 WindowFrameBound::OffsetFollowing(offset) => OffsetFollowing(*offset),
4146 WindowFrameBound::UnboundedFollowing => UnboundedFollowing(()),
4147 }),
4148 }
4149 }
4150
4151 fn from_proto(x: proto_window_frame::ProtoWindowFrameBound) -> Result<Self, TryFromProtoError> {
4152 use proto_window_frame::proto_window_frame_bound::Kind::*;
4153 Ok(match x.kind {
4154 Some(UnboundedPreceding(())) => WindowFrameBound::UnboundedPreceding,
4155 Some(OffsetPreceding(offset)) => WindowFrameBound::OffsetPreceding(offset),
4156 Some(CurrentRow(())) => WindowFrameBound::CurrentRow,
4157 Some(OffsetFollowing(offset)) => WindowFrameBound::OffsetFollowing(offset),
4158 Some(UnboundedFollowing(())) => WindowFrameBound::UnboundedFollowing,
4159 None => {
4160 return Err(TryFromProtoError::missing_field(
4161 "ProtoWindowFrameBound::kind",
4162 ));
4163 }
4164 })
4165 }
4166}
4167
/// Maximum iterations for a LetRec.
///
/// Corresponds to the RECURSION LIMIT option of WITH MUTUALLY RECURSIVE.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Arbitrary,
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to throw an error when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
4179
4180impl LetRecLimit {
4181 /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4182 pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4183 limits
4184 .iter()
4185 .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4186 .min()
4187 }
4188
4189 /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4190 /// WMR without ERROR AT or RETURN AT.
4191 pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4192}
4193
4194impl Display for LetRecLimit {
4195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4196 write!(f, "[recursion_limit={}", self.max_iters)?;
4197 if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4198 write!(f, ", return_at_limit")?;
4199 }
4200 write!(f, "]")
4201 }
4202}
4203
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, Arbitrary)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    /// There may be multiple entries when several indexes on the same collection serve
    /// different usage types.
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4220
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use mz_repr::explain::text::text_string_at;
    use proptest::prelude::*;

    use crate::explain::HumanizedExplain;

    use super::*;

    // Property test: `ColumnOrder` survives a protobuf encode/decode roundtrip.
    proptest! {
        #[mz_ore::test]
        fn column_order_protobuf_roundtrip(expect in any::<ColumnOrder>()) {
            let actual = protobuf_roundtrip::<_, ProtoColumnOrder>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Property test: `AggregateExpr` survives a protobuf roundtrip.
    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decContextDefault` on OS `linux`
        fn aggregate_expr_protobuf_roundtrip(expect in any::<AggregateExpr>()) {
            let actual = protobuf_roundtrip::<_, ProtoAggregateExpr>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Property test: `WindowFrameUnits` survives a protobuf roundtrip.
    proptest! {
        #[mz_ore::test]
        fn window_frame_units_protobuf_roundtrip(expect in any::<WindowFrameUnits>()) {
            let actual = protobuf_roundtrip::<_, proto_window_frame::ProtoWindowFrameUnits>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Property test: `WindowFrameBound` survives a protobuf roundtrip.
    proptest! {
        #[mz_ore::test]
        fn window_frame_bound_protobuf_roundtrip(expect in any::<WindowFrameBound>()) {
            let actual = protobuf_roundtrip::<_, proto_window_frame::ProtoWindowFrameBound>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Property test: `WindowFrame` survives a protobuf roundtrip.
    proptest! {
        #[mz_ore::test]
        fn window_frame_protobuf_roundtrip(expect in any::<WindowFrame>()) {
            let actual = protobuf_roundtrip::<_, ProtoWindowFrame>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Golden test: text rendering of a `RowSetFinishing` with an ORDER BY,
    // a LIMIT, and a projection that exercises the `#3..=#5` range syntax.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // Build the expected single-line rendering piece by piece.
        let exp = {
            use mz_ore::fmt::FormatBuffer;
            let mut s = String::new();
            write!(&mut s, "Finish");
            write!(&mut s, " order_by=[#4 desc nulls_last]");
            write!(&mut s, " limit=7");
            write!(&mut s, " output=[#1, #3..=#5]");
            writeln!(&mut s, "");
            s
        };

        assert_eq!(act, exp);
    }
}
4310
/// An iterator over AST structures, which calls out nodes in difference.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
/// to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        /// Yields the next pair of differing subexpressions, or `None` when the
        /// remaining worklist has been exhausted without finding a difference.
        ///
        /// Uses an explicit worklist (`self.todo`) instead of recursion: each
        /// step pops a pair, compares the node-local data, and — when the nodes
        /// agree — pushes the corresponding child pairs for later comparison.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    // Leaf node: all data is node-local, so any mismatch is final.
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Leaf node: compares id, type, and access strategy.
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Recursive nodes from here on: when node-local data agrees,
                    // descend into the children instead of reporting.
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        // A length mismatch is reported at this node, since the
                        // values cannot be zipped pairwise.
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Unary nodes with no node-local data: always descend.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // The two nodes are different variants: report the pair.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            // Worklist exhausted: the (remaining) trees are identical.
            None
        }
    }
}