1use std::collections::BTreeMap;
11
12use itertools::Itertools;
13use mz_expr::{AccessStrategy, EvalError, Id, LocalId, MirRelationExpr, MirScalarExpr};
14use mz_lowertest::*;
15use mz_ore::cast::CastFrom;
16use mz_ore::result::ResultExt;
17use mz_ore::str::separated;
18use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
19use mz_repr::{ColumnType, Diff, GlobalId, RelationType, Row, ScalarType};
20use mz_repr_test_util::*;
21use proc_macro2::TokenTree;
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24
25pub fn build_scalar(s: &str) -> Result<MirScalarExpr, String> {
29 deserialize(
30 &mut tokenize(s)?.into_iter(),
31 "MirScalarExpr",
32 &mut MirScalarExprDeserializeContext::default(),
33 )
34}
35
36pub fn build_rel(s: &str, catalog: &TestCatalog) -> Result<MirRelationExpr, String> {
40 deserialize(
41 &mut tokenize(s)?.into_iter(),
42 "MirRelationExpr",
43 &mut MirRelationExprDeserializeContext::new(catalog),
44 )
45}
46
47pub fn json_to_spec(rel_json: &str, catalog: &TestCatalog) -> (String, Vec<String>) {
55 let mut ctx = MirRelationExprDeserializeContext::new(catalog);
56 let spec = serialize::<MirRelationExpr, _>(
57 &serde_json::from_str(rel_json).unwrap(),
58 "MirRelationExpr",
59 &mut ctx,
60 );
61 let mut source_defs = ctx
62 .list_scope_references()
63 .map(|(name, typ)| {
64 format!(
65 "(defsource {} {})",
66 name,
67 serialize_generic::<RelationType>(
68 &serde_json::to_value(typ).unwrap(),
69 "RelationType",
70 )
71 )
72 })
73 .collect::<Vec<_>>();
74 source_defs.sort();
75 (spec, source_defs)
76}
77
/// A minimal catalog of named relations for use in tests.
///
/// Objects are registered through `defsource` test commands or directly via
/// [`TestCatalog::insert`].
#[derive(Debug, Default)]
pub struct TestCatalog {
    /// Maps each object's name to its id and relation type.
    objects: BTreeMap<String, (GlobalId, RelationType)>,
    /// Reverse mapping from an object's id to its name.
    names: BTreeMap<GlobalId, String>,
}
87
/// Commands understood by [`TestCatalog::handle_test_command`].
#[derive(Debug, Serialize, Deserialize, MzReflect)]
enum TestCatalogCommand {
    /// Registers a non-transient object called `name` with relation type `typ`,
    /// e.g. `(defsource <name> <RelationType>)`.
    Defsource { name: String, typ: RelationType },
}
96
97impl<'a> TestCatalog {
98 pub fn insert(
107 &mut self,
108 name: &str,
109 typ: RelationType,
110 transient: bool,
111 ) -> Result<GlobalId, String> {
112 if self.objects.contains_key(name) {
113 return Err(format!("Object {} already exists in catalog", name));
114 }
115 let id = if transient {
116 GlobalId::Transient(u64::cast_from(self.objects.len()))
117 } else {
118 GlobalId::User(u64::cast_from(self.objects.len()))
119 };
120 self.objects.insert(name.to_string(), (id, typ));
121 self.names.insert(id, name.to_string());
122 Ok(id)
123 }
124
125 fn get(&'a self, name: &str) -> Option<&'a (GlobalId, RelationType)> {
126 self.objects.get(name)
127 }
128
129 pub fn get_source_name(&'a self, id: &GlobalId) -> Option<&'a String> {
131 self.names.get(id)
132 }
133
134 pub fn handle_test_command(&mut self, spec: &str) -> Result<(), String> {
140 let mut stream_iter = tokenize(spec)?.into_iter();
141 while let Some(command) = deserialize_optional_generic::<TestCatalogCommand, _>(
142 &mut stream_iter,
143 "TestCatalogCommand",
144 )? {
145 match command {
146 TestCatalogCommand::Defsource { name, typ } => {
147 self.insert(&name, typ, false)?;
148 }
149 }
150 }
151 Ok(())
152 }
153
154 pub fn remove_transient_objects(&mut self) {
156 self.objects.retain(|_, (id, _)| {
157 if let GlobalId::Transient(_) = id {
158 false
159 } else {
160 true
161 }
162 });
163 self.names.retain(|k, _| {
164 if let GlobalId::Transient(_) = k {
165 false
166 } else {
167 true
168 }
169 });
170 }
171}
172
173impl ExprHumanizer for TestCatalog {
174 fn humanize_id(&self, id: GlobalId) -> Option<String> {
175 self.names.get(&id).map(|s| s.to_string())
176 }
177
178 fn humanize_id_unqualified(&self, id: GlobalId) -> Option<String> {
179 self.names.get(&id).map(|s| s.to_string())
180 }
181
182 fn humanize_id_parts(&self, id: GlobalId) -> Option<Vec<String>> {
183 self.humanize_id_unqualified(id).map(|name| vec![name])
184 }
185
186 fn humanize_scalar_type(&self, ty: &ScalarType, postgres_compat: bool) -> String {
187 DummyHumanizer.humanize_scalar_type(ty, postgres_compat)
188 }
189
190 fn column_names_for_id(&self, _id: GlobalId) -> Option<Vec<String>> {
191 None
192 }
193
194 fn humanize_column(&self, _id: GlobalId, _column: usize) -> Option<String> {
195 None
196 }
197
198 fn id_exists(&self, id: GlobalId) -> bool {
199 self.names.contains_key(&id)
200 }
201}
202
/// Adds shorthand syntax for `MirScalarExpr` on top of the generic test-syntax
/// deserializer:
///
/// * `#<n>` is a reference to column `n`.
/// * `ok <literal> ...`, `err <EvalError> [ScalarType]`, or a bare literal
///   optionally followed by its type is a literal.
///
/// When serializing, columns print as `#n` and literals print as
/// `(<datum> <type>)` or `(err <error> <type>)`.
#[derive(Default)]
pub struct MirScalarExprDeserializeContext;
217
218impl MirScalarExprDeserializeContext {
219 fn build_column(&self, token: Option<TokenTree>) -> Result<MirScalarExpr, String> {
220 if let Some(TokenTree::Literal(literal)) = token {
221 return Ok(MirScalarExpr::column(
222 literal
223 .to_string()
224 .parse::<usize>()
225 .map_err_to_string_with_causes()?,
226 ));
227 }
228 Err(format!(
229 "Invalid column specification {:?}",
230 token.map(|token_tree| format!("`{}`", token_tree))
231 ))
232 }
233
234 fn build_literal_if_able<I>(
235 &self,
236 first_arg: TokenTree,
237 rest_of_stream: &mut I,
238 ) -> Result<Option<MirScalarExpr>, String>
239 where
240 I: Iterator<Item = TokenTree>,
241 {
242 match &first_arg {
243 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("ok") => {
244 let first_arg = if let Some(first_arg) = rest_of_stream.next() {
246 first_arg
247 } else {
248 return Err(format!("expected literal after Ident: `{}`", i));
249 };
250 match self.build_literal_ok_if_able(first_arg, rest_of_stream) {
251 Ok(Some(l)) => Ok(Some(l)),
252 _ => Err(format!("expected literal after Ident: `{}`", i)),
253 }
254 }
255 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("err") => {
256 let error = deserialize_generic(rest_of_stream, "EvalError")?;
257 let typ: Option<ScalarType> =
258 deserialize_optional_generic(rest_of_stream, "ScalarType")?;
259 Ok(Some(MirScalarExpr::literal(
260 Err(error),
261 typ.unwrap_or(ScalarType::Bool),
262 )))
263 }
264 _ => self.build_literal_ok_if_able(first_arg, rest_of_stream),
265 }
266 }
267
268 fn build_literal_ok_if_able<I>(
269 &self,
270 first_arg: TokenTree,
271 rest_of_stream: &mut I,
272 ) -> Result<Option<MirScalarExpr>, String>
273 where
274 I: Iterator<Item = TokenTree>,
275 {
276 match extract_literal_string(&first_arg, rest_of_stream)? {
277 Some(litval) => {
278 let littyp = get_scalar_type_or_default(&litval[..], rest_of_stream)?;
279 Ok(Some(MirScalarExpr::Literal(
280 Ok(test_spec_to_row(std::iter::once((&litval[..], &littyp)))?),
281 littyp.nullable(matches!(&litval[..], "null")),
282 )))
283 }
284 None => Ok(None),
285 }
286 }
287}
288
289impl TestDeserializeContext for MirScalarExprDeserializeContext {
290 fn override_syntax<I>(
291 &mut self,
292 first_arg: TokenTree,
293 rest_of_stream: &mut I,
294 type_name: &str,
295 ) -> Result<Option<String>, String>
296 where
297 I: Iterator<Item = TokenTree>,
298 {
299 let result = if type_name == "MirScalarExpr" {
300 match first_arg {
301 TokenTree::Punct(punct) if punct.as_char() == '#' => {
302 Some(self.build_column(rest_of_stream.next())?)
303 }
304 TokenTree::Group(_) => None,
305 symbol => self.build_literal_if_able(symbol, rest_of_stream)?,
306 }
307 } else {
308 None
309 };
310 match result {
311 Some(result) => Ok(Some(
312 serde_json::to_string(&result).map_err_to_string_with_causes()?,
313 )),
314 None => Ok(None),
315 }
316 }
317
318 fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
319 if type_name == "MirScalarExpr" {
320 let map = json.as_object().unwrap();
321 assert_eq!(map.len(), 1);
323 for (variant, data) in map.iter() {
324 match &variant[..] {
325 "Column" => {
326 return Some(format!(
327 "#{}",
328 data.as_array().unwrap()[0].as_u64().unwrap()
329 ));
330 }
331 "Literal" => {
332 let column_type: ColumnType =
333 serde_json::from_value(data.as_array().unwrap()[1].clone()).unwrap();
334 let obj = data.as_array().unwrap()[0].as_object().unwrap();
335 if let Some(inner_data) = obj.get("Ok") {
336 let row: Row = serde_json::from_value(inner_data.clone()).unwrap();
337 let result = format!(
338 "({} {})",
339 datum_to_test_spec(row.unpack_first()),
340 serialize::<ScalarType, _>(
341 &serde_json::to_value(&column_type.scalar_type).unwrap(),
342 "ScalarType",
343 self
344 )
345 );
346 return Some(result);
347 } else if let Some(inner_data) = obj.get("Err") {
348 let result = format!(
349 "(err {} {})",
350 serialize::<EvalError, _>(inner_data, "EvalError", self),
351 serialize::<ScalarType, _>(
352 &serde_json::to_value(&column_type.scalar_type).unwrap(),
353 "ScalarType",
354 self
355 ),
356 );
357 return Some(result);
358 } else {
359 unreachable!("unexpected JSON data: {:?}", obj);
360 }
361 }
362 _ => {}
363 }
364 }
365 }
366 None
367 }
368}
369
/// Adds shorthand syntax for `MirRelationExpr` on top of
/// [`MirScalarExprDeserializeContext`]: `constant`, `constant_err`, `get`,
/// `let`, and `union`. It also accepts `#<n>` wherever a `usize` column index
/// is expected.
pub struct MirRelationExprDeserializeContext<'a> {
    /// Handles scalar-expression syntax; it is consulted first.
    inner_ctx: MirScalarExprDeserializeContext,
    /// Catalog used to resolve `get` of names not bound by a `let`.
    catalog: &'a TestCatalog,
    /// While parsing, the names bound by enclosing `let`s. While serializing,
    /// it also records global objects that the catalog has no name for.
    scope: Scope,
}
405
406impl<'a> MirRelationExprDeserializeContext<'a> {
407 pub fn new(catalog: &'a TestCatalog) -> Self {
408 Self {
409 inner_ctx: MirScalarExprDeserializeContext::default(),
410 catalog,
411 scope: Scope::default(),
412 }
413 }
414
415 pub fn list_scope_references(&self) -> impl Iterator<Item = (&String, &RelationType)> {
416 self.scope.iter()
417 }
418
419 fn build_constant<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
420 where
421 I: Iterator<Item = TokenTree>,
422 {
423 let raw_rows = stream_iter
424 .next()
425 .ok_or_else(|| "Constant is empty".to_string())?;
426 let typ: RelationType = deserialize(stream_iter, "RelationType", self)?;
430
431 let mut rows = Vec::new();
432 match raw_rows {
433 TokenTree::Group(group) => {
434 let mut inner_iter = group.stream().into_iter();
435 while let Some(token) = inner_iter.next() {
436 let row = test_spec_to_row(
437 parse_vec_of_literals(&token)?
438 .iter()
439 .zip_eq(&typ.column_types)
440 .map(|(dat, col_typ)| (&dat[..], &col_typ.scalar_type)),
441 )?;
442 rows.push((row, Diff::ONE));
443 }
444 }
445 invalid => return Err(format!("invalid rows spec for constant `{}`", invalid)),
446 };
447 Ok(MirRelationExpr::Constant {
448 rows: Ok(rows),
449 typ,
450 })
451 }
452
453 fn build_constant_err<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
454 where
455 I: Iterator<Item = TokenTree>,
456 {
457 let error: EvalError = deserialize(stream_iter, "EvalError", self)?;
458 let typ: RelationType = deserialize(stream_iter, "RelationType", self)?;
459
460 Ok(MirRelationExpr::Constant {
461 rows: Err(error),
462 typ,
463 })
464 }
465
466 fn build_get(&self, token: Option<TokenTree>) -> Result<MirRelationExpr, String> {
467 match token {
468 Some(TokenTree::Ident(ident)) => {
469 let name = ident.to_string();
470 match self.scope.get(&name) {
471 Some((id, typ)) => Ok(MirRelationExpr::Get {
472 id,
473 typ,
474 access_strategy: AccessStrategy::UnknownOrLocal,
475 }),
476 None => match self.catalog.get(&name) {
477 None => Err(format!("no catalog object named {}", name)),
478 Some((id, typ)) => Ok(MirRelationExpr::Get {
479 id: Id::Global(*id),
480 typ: typ.clone(),
481 access_strategy: AccessStrategy::UnknownOrLocal,
482 }),
483 },
484 }
485 }
486 invalid_token => Err(format!(
487 "Invalid get specification {:?}",
488 invalid_token.map(|token_tree| format!("`{}`", token_tree))
489 )),
490 }
491 }
492
493 fn build_let<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
494 where
495 I: Iterator<Item = TokenTree>,
496 {
497 let name = match stream_iter.next() {
498 Some(TokenTree::Ident(ident)) => Ok(ident.to_string()),
499 invalid_token => Err(format!(
500 "Invalid let specification {:?}",
501 invalid_token.map(|token_tree| format!("`{}`", token_tree))
502 )),
503 }?;
504
505 let value: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
506
507 let (id, prev) = self.scope.insert(&name, value.typ());
508
509 let body: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
510
511 if let Some((old_id, old_val)) = prev {
512 self.scope.set(&name, old_id, old_val);
513 } else {
514 self.scope.remove(&name)
515 }
516
517 Ok(MirRelationExpr::Let {
518 id,
519 value: Box::new(value),
520 body: Box::new(body),
521 })
522 }
523
524 fn build_union<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
525 where
526 I: Iterator<Item = TokenTree>,
527 {
528 let mut inputs: Vec<MirRelationExpr> =
529 deserialize(stream_iter, "Vec<MirRelationExpr>", self)?;
530 Ok(MirRelationExpr::Union {
531 base: Box::new(inputs.remove(0)),
532 inputs,
533 })
534 }
535
536 fn build_special_mir_if_able<I>(
537 &mut self,
538 first_arg: TokenTree,
539 rest_of_stream: &mut I,
540 ) -> Result<Option<MirRelationExpr>, String>
541 where
542 I: Iterator<Item = TokenTree>,
543 {
544 if let TokenTree::Ident(ident) = first_arg {
545 return Ok(match &ident.to_string().to_lowercase()[..] {
546 "constant" => Some(self.build_constant(rest_of_stream)?),
547 "constant_err" => Some(self.build_constant_err(rest_of_stream)?),
548 "get" => Some(self.build_get(rest_of_stream.next())?),
549 "let" => Some(self.build_let(rest_of_stream)?),
550 "union" => Some(self.build_union(rest_of_stream)?),
551 _ => None,
552 });
553 }
554 Ok(None)
555 }
556}
557
558impl<'a> TestDeserializeContext for MirRelationExprDeserializeContext<'a> {
559 fn override_syntax<I>(
560 &mut self,
561 first_arg: TokenTree,
562 rest_of_stream: &mut I,
563 type_name: &str,
564 ) -> Result<Option<String>, String>
565 where
566 I: Iterator<Item = TokenTree>,
567 {
568 match self
569 .inner_ctx
570 .override_syntax(first_arg.clone(), rest_of_stream, type_name)?
571 {
572 Some(result) => Ok(Some(result)),
573 None => {
574 if type_name == "MirRelationExpr" {
575 if let Some(result) =
576 self.build_special_mir_if_able(first_arg, rest_of_stream)?
577 {
578 return Ok(Some(
579 serde_json::to_string(&result).map_err_to_string_with_causes()?,
580 ));
581 }
582 } else if type_name == "usize" {
583 if let TokenTree::Punct(punct) = first_arg {
584 if punct.as_char() == '#' {
585 match rest_of_stream.next() {
586 Some(TokenTree::Literal(literal)) => {
587 return Ok(Some(literal.to_string()));
588 }
589 invalid => {
590 return Err(format!(
591 "invalid column value {:?}",
592 invalid.map(|token_tree| format!("`{}`", token_tree))
593 ));
594 }
595 }
596 }
597 }
598 }
599 Ok(None)
600 }
601 }
602 }
603
604 fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
605 match self.inner_ctx.reverse_syntax_override(json, type_name) {
606 Some(result) => Some(result),
607 None => {
608 if type_name == "MirRelationExpr" {
609 let map = json.as_object().unwrap();
610 assert_eq!(
612 map.len(),
613 1,
614 "Multivariant instance {:?} found for MirRelationExpr",
615 map
616 );
617 for (variant, data) in map.iter() {
618 let inner_map = data.as_object().unwrap();
619 match &variant[..] {
620 "Let" => {
621 let id: LocalId =
622 serde_json::from_value(inner_map["id"].clone()).unwrap();
623 return Some(format!(
624 "(let {} {} {})",
625 id,
626 serialize::<MirRelationExpr, _>(
627 &inner_map["value"],
628 "MirRelationExpr",
629 self
630 ),
631 serialize::<MirRelationExpr, _>(
632 &inner_map["body"],
633 "MirRelationExpr",
634 self
635 ),
636 ));
637 }
638 "Get" => {
639 let id: Id =
640 serde_json::from_value(inner_map["id"].clone()).unwrap();
641 return Some(match id {
642 Id::Global(global) => {
643 match self.catalog.get_source_name(&global) {
644 Some(source) => format!("(get {})", source),
647 None => {
649 let typ: RelationType = serde_json::from_value(
650 inner_map["typ"].clone(),
651 )
652 .unwrap();
653 self.scope.insert(&id.to_string(), typ);
654 format!("(get {})", id)
655 }
656 }
657 }
658 _ => {
659 format!("(get {})", id)
660 }
661 });
662 }
663 "Constant" => {
664 if let Some(row_vec) = inner_map["rows"].get("Ok") {
665 let mut rows = Vec::new();
666 for inner_array in row_vec.as_array().unwrap() {
667 let row: Row =
668 serde_json::from_value(inner_array[0].clone()).unwrap();
669 let diff = inner_array[1].as_u64().unwrap();
670 for _ in 0..diff {
671 rows.push(format!(
672 "[{}]",
673 separated(" ", row.iter().map(datum_to_test_spec))
674 ))
675 }
676 }
677 return Some(format!(
678 "(constant [{}] {})",
679 separated(" ", rows),
680 serialize::<RelationType, _>(
681 &inner_map["typ"],
682 "RelationType",
683 self
684 )
685 ));
686 } else if let Some(inner_data) = inner_map["rows"].get("Err") {
687 return Some(format!(
688 "(constant_err {} {})",
689 serialize::<EvalError, _>(inner_data, "EvalError", self),
690 serialize::<RelationType, _>(
691 &inner_map["typ"],
692 "RelationType",
693 self
694 )
695 ));
696 } else {
697 unreachable!("unexpected JSON data: {:?}", inner_map);
698 }
699 }
700 "Union" => {
701 let mut inputs = inner_map["inputs"].as_array().unwrap().to_owned();
702 inputs.insert(0, inner_map["base"].clone());
703 return Some(format!(
704 "(union {})",
705 serialize::<Vec<MirRelationExpr>, _>(
706 &Value::Array(inputs),
707 "Vec<MirRelationExpr>",
708 self
709 )
710 ));
711 }
712 _ => {}
713 }
714 }
715 }
716 None
717 }
718 }
719 }
720}
721
/// Names bound while converting between the test syntax and `MirRelationExpr`.
#[derive(Debug, Default)]
struct Scope {
    /// Maps each visible name to its id and relation type.
    objects: BTreeMap<String, (Id, RelationType)>,
    /// Reverse mapping from ids to the names they were bound to.
    names: BTreeMap<Id, String>,
}
729
730impl Scope {
731 fn insert(&mut self, name: &str, typ: RelationType) -> (LocalId, Option<(Id, RelationType)>) {
732 let old_val = self.get(name);
733 let id = LocalId::new(u64::cast_from(self.objects.len()));
734 self.set(name, Id::Local(id), typ);
735 (id, old_val)
736 }
737
738 fn set(&mut self, name: &str, id: Id, typ: RelationType) {
739 self.objects.insert(name.to_string(), (id, typ));
740 self.names.insert(id, name.to_string());
741 }
742
743 fn remove(&mut self, name: &str) {
744 self.objects.remove(name);
745 }
746
747 fn get(&self, name: &str) -> Option<(Id, RelationType)> {
748 self.objects.get(name).cloned()
749 }
750
751 fn iter(&self) -> impl Iterator<Item = (&String, &RelationType)> {
752 self.objects.iter().map(|(s, (_, typ))| (s, typ))
753 }
754}