1use std::collections::BTreeMap;
11
12use mz_expr::{AccessStrategy, EvalError, Id, LocalId, MirRelationExpr, MirScalarExpr};
13use mz_lowertest::*;
14use mz_ore::cast::CastFrom;
15use mz_ore::result::ResultExt;
16use mz_ore::str::separated;
17use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
18use mz_repr::{ColumnType, Diff, GlobalId, RelationType, Row, ScalarType};
19use mz_repr_test_util::*;
20use proc_macro2::TokenTree;
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23
24pub fn build_scalar(s: &str) -> Result<MirScalarExpr, String> {
28 deserialize(
29 &mut tokenize(s)?.into_iter(),
30 "MirScalarExpr",
31 &mut MirScalarExprDeserializeContext::default(),
32 )
33}
34
35pub fn build_rel(s: &str, catalog: &TestCatalog) -> Result<MirRelationExpr, String> {
39 deserialize(
40 &mut tokenize(s)?.into_iter(),
41 "MirRelationExpr",
42 &mut MirRelationExprDeserializeContext::new(catalog),
43 )
44}
45
46pub fn json_to_spec(rel_json: &str, catalog: &TestCatalog) -> (String, Vec<String>) {
54 let mut ctx = MirRelationExprDeserializeContext::new(catalog);
55 let spec = serialize::<MirRelationExpr, _>(
56 &serde_json::from_str(rel_json).unwrap(),
57 "MirRelationExpr",
58 &mut ctx,
59 );
60 let mut source_defs = ctx
61 .list_scope_references()
62 .map(|(name, typ)| {
63 format!(
64 "(defsource {} {})",
65 name,
66 serialize_generic::<RelationType>(
67 &serde_json::to_value(typ).unwrap(),
68 "RelationType",
69 )
70 )
71 })
72 .collect::<Vec<_>>();
73 source_defs.sort();
74 (spec, source_defs)
75}
76
77#[derive(Debug, Default)]
82pub struct TestCatalog {
83 objects: BTreeMap<String, (GlobalId, RelationType)>,
84 names: BTreeMap<GlobalId, String>,
85}
86
87#[derive(Debug, Serialize, Deserialize, MzReflect)]
91enum TestCatalogCommand {
92 Defsource { name: String, typ: RelationType },
94}
95
96impl<'a> TestCatalog {
97 pub fn insert(
106 &mut self,
107 name: &str,
108 typ: RelationType,
109 transient: bool,
110 ) -> Result<GlobalId, String> {
111 if self.objects.contains_key(name) {
112 return Err(format!("Object {} already exists in catalog", name));
113 }
114 let id = if transient {
115 GlobalId::Transient(u64::cast_from(self.objects.len()))
116 } else {
117 GlobalId::User(u64::cast_from(self.objects.len()))
118 };
119 self.objects.insert(name.to_string(), (id, typ));
120 self.names.insert(id, name.to_string());
121 Ok(id)
122 }
123
124 fn get(&'a self, name: &str) -> Option<&'a (GlobalId, RelationType)> {
125 self.objects.get(name)
126 }
127
128 pub fn get_source_name(&'a self, id: &GlobalId) -> Option<&'a String> {
130 self.names.get(id)
131 }
132
133 pub fn handle_test_command(&mut self, spec: &str) -> Result<(), String> {
139 let mut stream_iter = tokenize(spec)?.into_iter();
140 while let Some(command) = deserialize_optional_generic::<TestCatalogCommand, _>(
141 &mut stream_iter,
142 "TestCatalogCommand",
143 )? {
144 match command {
145 TestCatalogCommand::Defsource { name, typ } => {
146 self.insert(&name, typ, false)?;
147 }
148 }
149 }
150 Ok(())
151 }
152
153 pub fn remove_transient_objects(&mut self) {
155 self.objects.retain(|_, (id, _)| {
156 if let GlobalId::Transient(_) = id {
157 false
158 } else {
159 true
160 }
161 });
162 self.names.retain(|k, _| {
163 if let GlobalId::Transient(_) = k {
164 false
165 } else {
166 true
167 }
168 });
169 }
170}
171
172impl ExprHumanizer for TestCatalog {
173 fn humanize_id(&self, id: GlobalId) -> Option<String> {
174 self.names.get(&id).map(|s| s.to_string())
175 }
176
177 fn humanize_id_unqualified(&self, id: GlobalId) -> Option<String> {
178 self.names.get(&id).map(|s| s.to_string())
179 }
180
181 fn humanize_id_parts(&self, id: GlobalId) -> Option<Vec<String>> {
182 self.humanize_id_unqualified(id).map(|name| vec![name])
183 }
184
185 fn humanize_scalar_type(&self, ty: &ScalarType, postgres_compat: bool) -> String {
186 DummyHumanizer.humanize_scalar_type(ty, postgres_compat)
187 }
188
189 fn column_names_for_id(&self, _id: GlobalId) -> Option<Vec<String>> {
190 None
191 }
192
193 fn humanize_column(&self, _id: GlobalId, _column: usize) -> Option<String> {
194 None
195 }
196
197 fn id_exists(&self, id: GlobalId) -> bool {
198 self.names.contains_key(&id)
199 }
200}
201
/// Deserialization context for `MirScalarExpr` test specs.
///
/// It adds shorthand syntax to the default deserialization: `#n` for column
/// references, and literal constants (including `ok`/`err`-prefixed forms).
/// The context itself holds no state.
#[derive(Default)]
pub struct MirScalarExprDeserializeContext;
216
217impl MirScalarExprDeserializeContext {
218 fn build_column(&self, token: Option<TokenTree>) -> Result<MirScalarExpr, String> {
219 if let Some(TokenTree::Literal(literal)) = token {
220 return Ok(MirScalarExpr::Column(
221 literal
222 .to_string()
223 .parse::<usize>()
224 .map_err_to_string_with_causes()?,
225 ));
226 }
227 Err(format!(
228 "Invalid column specification {:?}",
229 token.map(|token_tree| format!("`{}`", token_tree))
230 ))
231 }
232
233 fn build_literal_if_able<I>(
234 &self,
235 first_arg: TokenTree,
236 rest_of_stream: &mut I,
237 ) -> Result<Option<MirScalarExpr>, String>
238 where
239 I: Iterator<Item = TokenTree>,
240 {
241 match &first_arg {
242 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("ok") => {
243 let first_arg = if let Some(first_arg) = rest_of_stream.next() {
245 first_arg
246 } else {
247 return Err(format!("expected literal after Ident: `{}`", i));
248 };
249 match self.build_literal_ok_if_able(first_arg, rest_of_stream) {
250 Ok(Some(l)) => Ok(Some(l)),
251 _ => Err(format!("expected literal after Ident: `{}`", i)),
252 }
253 }
254 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("err") => {
255 let error = deserialize_generic(rest_of_stream, "EvalError")?;
256 let typ: Option<ScalarType> =
257 deserialize_optional_generic(rest_of_stream, "ScalarType")?;
258 Ok(Some(MirScalarExpr::literal(
259 Err(error),
260 typ.unwrap_or(ScalarType::Bool),
261 )))
262 }
263 _ => self.build_literal_ok_if_able(first_arg, rest_of_stream),
264 }
265 }
266
267 fn build_literal_ok_if_able<I>(
268 &self,
269 first_arg: TokenTree,
270 rest_of_stream: &mut I,
271 ) -> Result<Option<MirScalarExpr>, String>
272 where
273 I: Iterator<Item = TokenTree>,
274 {
275 match extract_literal_string(&first_arg, rest_of_stream)? {
276 Some(litval) => {
277 let littyp = get_scalar_type_or_default(&litval[..], rest_of_stream)?;
278 Ok(Some(MirScalarExpr::Literal(
279 Ok(test_spec_to_row(std::iter::once((&litval[..], &littyp)))?),
280 littyp.nullable(matches!(&litval[..], "null")),
281 )))
282 }
283 None => Ok(None),
284 }
285 }
286}
287
288impl TestDeserializeContext for MirScalarExprDeserializeContext {
289 fn override_syntax<I>(
290 &mut self,
291 first_arg: TokenTree,
292 rest_of_stream: &mut I,
293 type_name: &str,
294 ) -> Result<Option<String>, String>
295 where
296 I: Iterator<Item = TokenTree>,
297 {
298 let result = if type_name == "MirScalarExpr" {
299 match first_arg {
300 TokenTree::Punct(punct) if punct.as_char() == '#' => {
301 Some(self.build_column(rest_of_stream.next())?)
302 }
303 TokenTree::Group(_) => None,
304 symbol => self.build_literal_if_able(symbol, rest_of_stream)?,
305 }
306 } else {
307 None
308 };
309 match result {
310 Some(result) => Ok(Some(
311 serde_json::to_string(&result).map_err_to_string_with_causes()?,
312 )),
313 None => Ok(None),
314 }
315 }
316
317 fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
318 if type_name == "MirScalarExpr" {
319 let map = json.as_object().unwrap();
320 assert_eq!(map.len(), 1);
322 for (variant, data) in map.iter() {
323 match &variant[..] {
324 "Column" => return Some(format!("#{}", data.as_u64().unwrap())),
325 "Literal" => {
326 let column_type: ColumnType =
327 serde_json::from_value(data.as_array().unwrap()[1].clone()).unwrap();
328 let obj = data.as_array().unwrap()[0].as_object().unwrap();
329 if let Some(inner_data) = obj.get("Ok") {
330 let row: Row = serde_json::from_value(inner_data.clone()).unwrap();
331 let result = format!(
332 "({} {})",
333 datum_to_test_spec(row.unpack_first()),
334 serialize::<ScalarType, _>(
335 &serde_json::to_value(&column_type.scalar_type).unwrap(),
336 "ScalarType",
337 self
338 )
339 );
340 return Some(result);
341 } else if let Some(inner_data) = obj.get("Err") {
342 let result = format!(
343 "(err {} {})",
344 serialize::<EvalError, _>(inner_data, "EvalError", self),
345 serialize::<ScalarType, _>(
346 &serde_json::to_value(&column_type.scalar_type).unwrap(),
347 "ScalarType",
348 self
349 ),
350 );
351 return Some(result);
352 } else {
353 unreachable!("unexpected JSON data: {:?}", obj);
354 }
355 }
356 _ => {}
357 }
358 }
359 }
360 None
361 }
362}
363
364pub struct MirRelationExprDeserializeContext<'a> {
392 inner_ctx: MirScalarExprDeserializeContext,
393 catalog: &'a TestCatalog,
394 scope: Scope,
398}
399
400impl<'a> MirRelationExprDeserializeContext<'a> {
401 pub fn new(catalog: &'a TestCatalog) -> Self {
402 Self {
403 inner_ctx: MirScalarExprDeserializeContext::default(),
404 catalog,
405 scope: Scope::default(),
406 }
407 }
408
409 pub fn list_scope_references(&self) -> impl Iterator<Item = (&String, &RelationType)> {
410 self.scope.iter()
411 }
412
413 fn build_constant<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
414 where
415 I: Iterator<Item = TokenTree>,
416 {
417 let raw_rows = stream_iter
418 .next()
419 .ok_or_else(|| "Constant is empty".to_string())?;
420 let typ: RelationType = deserialize(stream_iter, "RelationType", self)?;
424
425 let mut rows = Vec::new();
426 match raw_rows {
427 TokenTree::Group(group) => {
428 let mut inner_iter = group.stream().into_iter();
429 while let Some(token) = inner_iter.next() {
430 let row = test_spec_to_row(
431 parse_vec_of_literals(&token)?
432 .iter()
433 .zip(&typ.column_types)
434 .map(|(dat, col_typ)| (&dat[..], &col_typ.scalar_type)),
435 )?;
436 rows.push((row, Diff::ONE));
437 }
438 }
439 invalid => return Err(format!("invalid rows spec for constant `{}`", invalid)),
440 };
441 Ok(MirRelationExpr::Constant {
442 rows: Ok(rows),
443 typ,
444 })
445 }
446
447 fn build_constant_err<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
448 where
449 I: Iterator<Item = TokenTree>,
450 {
451 let error: EvalError = deserialize(stream_iter, "EvalError", self)?;
452 let typ: RelationType = deserialize(stream_iter, "RelationType", self)?;
453
454 Ok(MirRelationExpr::Constant {
455 rows: Err(error),
456 typ,
457 })
458 }
459
460 fn build_get(&self, token: Option<TokenTree>) -> Result<MirRelationExpr, String> {
461 match token {
462 Some(TokenTree::Ident(ident)) => {
463 let name = ident.to_string();
464 match self.scope.get(&name) {
465 Some((id, typ)) => Ok(MirRelationExpr::Get {
466 id,
467 typ,
468 access_strategy: AccessStrategy::UnknownOrLocal,
469 }),
470 None => match self.catalog.get(&name) {
471 None => Err(format!("no catalog object named {}", name)),
472 Some((id, typ)) => Ok(MirRelationExpr::Get {
473 id: Id::Global(*id),
474 typ: typ.clone(),
475 access_strategy: AccessStrategy::UnknownOrLocal,
476 }),
477 },
478 }
479 }
480 invalid_token => Err(format!(
481 "Invalid get specification {:?}",
482 invalid_token.map(|token_tree| format!("`{}`", token_tree))
483 )),
484 }
485 }
486
487 fn build_let<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
488 where
489 I: Iterator<Item = TokenTree>,
490 {
491 let name = match stream_iter.next() {
492 Some(TokenTree::Ident(ident)) => Ok(ident.to_string()),
493 invalid_token => Err(format!(
494 "Invalid let specification {:?}",
495 invalid_token.map(|token_tree| format!("`{}`", token_tree))
496 )),
497 }?;
498
499 let value: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
500
501 let (id, prev) = self.scope.insert(&name, value.typ());
502
503 let body: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
504
505 if let Some((old_id, old_val)) = prev {
506 self.scope.set(&name, old_id, old_val);
507 } else {
508 self.scope.remove(&name)
509 }
510
511 Ok(MirRelationExpr::Let {
512 id,
513 value: Box::new(value),
514 body: Box::new(body),
515 })
516 }
517
518 fn build_union<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
519 where
520 I: Iterator<Item = TokenTree>,
521 {
522 let mut inputs: Vec<MirRelationExpr> =
523 deserialize(stream_iter, "Vec<MirRelationExpr>", self)?;
524 Ok(MirRelationExpr::Union {
525 base: Box::new(inputs.remove(0)),
526 inputs,
527 })
528 }
529
530 fn build_special_mir_if_able<I>(
531 &mut self,
532 first_arg: TokenTree,
533 rest_of_stream: &mut I,
534 ) -> Result<Option<MirRelationExpr>, String>
535 where
536 I: Iterator<Item = TokenTree>,
537 {
538 if let TokenTree::Ident(ident) = first_arg {
539 return Ok(match &ident.to_string().to_lowercase()[..] {
540 "constant" => Some(self.build_constant(rest_of_stream)?),
541 "constant_err" => Some(self.build_constant_err(rest_of_stream)?),
542 "get" => Some(self.build_get(rest_of_stream.next())?),
543 "let" => Some(self.build_let(rest_of_stream)?),
544 "union" => Some(self.build_union(rest_of_stream)?),
545 _ => None,
546 });
547 }
548 Ok(None)
549 }
550}
551
impl<'a> TestDeserializeContext for MirRelationExprDeserializeContext<'a> {
    /// Parses special syntax while deserializing relation specs.
    ///
    /// Scalar syntax is offered to the inner scalar context first, which only
    /// acts on the `MirScalarExpr` type name. Otherwise, this handles the
    /// relation keywords (`constant`, `constant_err`, `get`, `let`, `union`)
    /// for `MirRelationExpr`, and `#<literal>` column references for `usize`.
    /// Returns `Ok(None)` when no override applies, so the default
    /// deserialization is used.
    fn override_syntax<I>(
        &mut self,
        first_arg: TokenTree,
        rest_of_stream: &mut I,
        type_name: &str,
    ) -> Result<Option<String>, String>
    where
        I: Iterator<Item = TokenTree>,
    {
        // `first_arg` is cloned because it is still needed below when the
        // scalar context declines to handle it.
        match self
            .inner_ctx
            .override_syntax(first_arg.clone(), rest_of_stream, type_name)?
        {
            Some(result) => Ok(Some(result)),
            None => {
                if type_name == "MirRelationExpr" {
                    if let Some(result) =
                        self.build_special_mir_if_able(first_arg, rest_of_stream)?
                    {
                        return Ok(Some(
                            serde_json::to_string(&result).map_err_to_string_with_causes()?,
                        ));
                    }
                } else if type_name == "usize" {
                    // `#n` is accepted as shorthand for the column index `n`;
                    // the literal's text is passed through unchanged.
                    if let TokenTree::Punct(punct) = first_arg {
                        if punct.as_char() == '#' {
                            match rest_of_stream.next() {
                                Some(TokenTree::Literal(literal)) => {
                                    return Ok(Some(literal.to_string()));
                                }
                                invalid => {
                                    return Err(format!(
                                        "invalid column value {:?}",
                                        invalid.map(|token_tree| format!("`{}`", token_tree))
                                    ));
                                }
                            }
                        }
                    }
                }
                Ok(None)
            }
        }
    }

    /// Converts the JSON form of an expression back into test-spec syntax.
    ///
    /// Scalar expressions are delegated to the inner scalar context. For a
    /// `MirRelationExpr`, the `Let`, `Get`, `Constant`, and `Union` variants
    /// are rendered in the keyword syntax accepted by `override_syntax`. Other
    /// variants return `None`, so the default serialization is used.
    ///
    /// Side effect: a global id that is missing from the catalog is recorded in
    /// `self.scope`, so that `list_scope_references` can report it.
    ///
    /// Panics if `json` does not have the shape produced by serializing a
    /// `MirRelationExpr`, or if a constant row's diff is not a non-negative
    /// integer.
    fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
        match self.inner_ctx.reverse_syntax_override(json, type_name) {
            Some(result) => Some(result),
            None => {
                if type_name == "MirRelationExpr" {
                    let map = json.as_object().unwrap();
                    // Each enum value is expected as a single-key object of
                    // the form `{ "<Variant>": <data> }`.
                    assert_eq!(
                        map.len(),
                        1,
                        "Multivariant instance {:?} found for MirRelationExpr",
                        map
                    );
                    for (variant, data) in map.iter() {
                        let inner_map = data.as_object().unwrap();
                        match &variant[..] {
                            "Let" => {
                                // The local id's `Display` form is used as the
                                // binding name.
                                let id: LocalId =
                                    serde_json::from_value(inner_map["id"].clone()).unwrap();
                                return Some(format!(
                                    "(let {} {} {})",
                                    id,
                                    serialize::<MirRelationExpr, _>(
                                        &inner_map["value"],
                                        "MirRelationExpr",
                                        self
                                    ),
                                    serialize::<MirRelationExpr, _>(
                                        &inner_map["body"],
                                        "MirRelationExpr",
                                        self
                                    ),
                                ));
                            }
                            "Get" => {
                                let id: Id =
                                    serde_json::from_value(inner_map["id"].clone()).unwrap();
                                return Some(match id {
                                    Id::Global(global) => {
                                        match self.catalog.get_source_name(&global) {
                                            // Known catalog objects are referenced by name.
                                            Some(source) => format!("(get {})", source),
                                            None => {
                                                // Unknown global ids are recorded in the
                                                // scope with their type, so `json_to_spec`
                                                // can emit a `defsource` for them.
                                                let typ: RelationType = serde_json::from_value(
                                                    inner_map["typ"].clone(),
                                                )
                                                .unwrap();
                                                self.scope.insert(&id.to_string(), typ);
                                                format!("(get {})", id)
                                            }
                                        }
                                    }
                                    _ => {
                                        // Local ids are printed as-is, which matches
                                        // the names emitted by the `Let` arm.
                                        format!("(get {})", id)
                                    }
                                });
                            }
                            "Constant" => {
                                if let Some(row_vec) = inner_map["rows"].get("Ok") {
                                    let mut rows = Vec::new();
                                    // Each entry is `[row, diff]`. The spec syntax has no
                                    // diff column, so a row is repeated `diff` times.
                                    for inner_array in row_vec.as_array().unwrap() {
                                        let row: Row =
                                            serde_json::from_value(inner_array[0].clone()).unwrap();
                                        let diff = inner_array[1].as_u64().unwrap();
                                        for _ in 0..diff {
                                            rows.push(format!(
                                                "[{}]",
                                                separated(" ", row.iter().map(datum_to_test_spec))
                                            ))
                                        }
                                    }
                                    return Some(format!(
                                        "(constant [{}] {})",
                                        separated(" ", rows),
                                        serialize::<RelationType, _>(
                                            &inner_map["typ"],
                                            "RelationType",
                                            self
                                        )
                                    ));
                                } else if let Some(inner_data) = inner_map["rows"].get("Err") {
                                    return Some(format!(
                                        "(constant_err {} {})",
                                        serialize::<EvalError, _>(inner_data, "EvalError", self),
                                        serialize::<RelationType, _>(
                                            &inner_map["typ"],
                                            "RelationType",
                                            self
                                        )
                                    ));
                                } else {
                                    unreachable!("unexpected JSON data: {:?}", inner_map);
                                }
                            }
                            "Union" => {
                                // `base` is prepended to `inputs` to produce the single
                                // list form that `build_union` accepts.
                                let mut inputs = inner_map["inputs"].as_array().unwrap().to_owned();
                                inputs.insert(0, inner_map["base"].clone());
                                return Some(format!(
                                    "(union {})",
                                    serialize::<Vec<MirRelationExpr>, _>(
                                        &Value::Array(inputs),
                                        "Vec<MirRelationExpr>",
                                        self
                                    )
                                ));
                            }
                            _ => {}
                        }
                    }
                }
                None
            }
        }
    }
}
715
716#[derive(Debug, Default)]
719struct Scope {
720 objects: BTreeMap<String, (Id, RelationType)>,
721 names: BTreeMap<Id, String>,
722}
723
724impl Scope {
725 fn insert(&mut self, name: &str, typ: RelationType) -> (LocalId, Option<(Id, RelationType)>) {
726 let old_val = self.get(name);
727 let id = LocalId::new(u64::cast_from(self.objects.len()));
728 self.set(name, Id::Local(id), typ);
729 (id, old_val)
730 }
731
732 fn set(&mut self, name: &str, id: Id, typ: RelationType) {
733 self.objects.insert(name.to_string(), (id, typ));
734 self.names.insert(id, name.to_string());
735 }
736
737 fn remove(&mut self, name: &str) {
738 self.objects.remove(name);
739 }
740
741 fn get(&self, name: &str) -> Option<(Id, RelationType)> {
742 self.objects.get(name).cloned()
743 }
744
745 fn iter(&self) -> impl Iterator<Item = (&String, &RelationType)> {
746 self.objects.iter().map(|(s, (_, typ))| (s, typ))
747 }
748}