1use std::collections::BTreeMap;
11
12use itertools::Itertools;
13use mz_expr::{AccessStrategy, EvalError, Id, LocalId, MirRelationExpr, MirScalarExpr};
14use mz_lowertest::*;
15use mz_ore::cast::CastFrom;
16use mz_ore::result::ResultExt;
17use mz_ore::str::separated;
18use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
19use mz_repr::{
20 Diff, GlobalId, ReprColumnType, ReprRelationType, ReprScalarType, Row, SqlRelationType,
21 SqlScalarType,
22};
23use mz_repr_test_util::*;
24use proc_macro2::TokenTree;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27
28pub fn build_scalar(s: &str) -> Result<MirScalarExpr, String> {
32 deserialize(
33 &mut tokenize(s)?.into_iter(),
34 "MirScalarExpr",
35 &mut MirScalarExprDeserializeContext::default(),
36 )
37}
38
39pub fn build_rel(s: &str, catalog: &TestCatalog) -> Result<MirRelationExpr, String> {
43 deserialize(
44 &mut tokenize(s)?.into_iter(),
45 "MirRelationExpr",
46 &mut MirRelationExprDeserializeContext::new(catalog),
47 )
48}
49
50pub fn json_to_spec(rel_json: &str, catalog: &TestCatalog) -> (String, Vec<String>) {
58 let mut ctx = MirRelationExprDeserializeContext::new(catalog);
59 let spec = serialize::<MirRelationExpr, _>(
60 &serde_json::from_str(rel_json).unwrap(),
61 "MirRelationExpr",
62 &mut ctx,
63 );
64 let mut source_defs = ctx
65 .list_scope_references()
66 .map(|(name, repr_typ)| {
67 let sql_typ = SqlRelationType::from_repr(repr_typ);
68 format!(
69 "(defsource {} {})",
70 name,
71 serialize_generic::<SqlRelationType>(
72 &serde_json::to_value(&sql_typ).unwrap(),
73 "SqlRelationType",
74 )
75 )
76 })
77 .collect::<Vec<_>>();
78 source_defs.sort();
79 (spec, source_defs)
80}
81
82#[derive(Debug, Default)]
87pub struct TestCatalog {
88 objects: BTreeMap<String, (GlobalId, SqlRelationType)>,
89 names: BTreeMap<GlobalId, String>,
90}
91
92#[derive(Debug, Serialize, Deserialize, MzReflect)]
96enum TestCatalogCommand {
97 Defsource { name: String, typ: SqlRelationType },
99}
100
101impl<'a> TestCatalog {
102 pub fn insert(
111 &mut self,
112 name: &str,
113 typ: SqlRelationType,
114 transient: bool,
115 ) -> Result<GlobalId, String> {
116 if self.objects.contains_key(name) {
117 return Err(format!("Object {} already exists in catalog", name));
118 }
119 let id = if transient {
120 GlobalId::Transient(u64::cast_from(self.objects.len()))
121 } else {
122 GlobalId::User(u64::cast_from(self.objects.len()))
123 };
124 self.objects.insert(name.to_string(), (id, typ));
125 self.names.insert(id, name.to_string());
126 Ok(id)
127 }
128
129 fn get(&'a self, name: &str) -> Option<&'a (GlobalId, SqlRelationType)> {
130 self.objects.get(name)
131 }
132
133 pub fn get_source_name(&'a self, id: &GlobalId) -> Option<&'a String> {
135 self.names.get(id)
136 }
137
138 pub fn handle_test_command(&mut self, spec: &str) -> Result<(), String> {
144 let mut stream_iter = tokenize(spec)?.into_iter();
145 while let Some(command) = deserialize_optional_generic::<TestCatalogCommand, _>(
146 &mut stream_iter,
147 "TestCatalogCommand",
148 )? {
149 match command {
150 TestCatalogCommand::Defsource { name, typ } => {
151 self.insert(&name, typ, false)?;
152 }
153 }
154 }
155 Ok(())
156 }
157
158 pub fn remove_transient_objects(&mut self) {
160 self.objects.retain(|_, (id, _)| {
161 if let GlobalId::Transient(_) = id {
162 false
163 } else {
164 true
165 }
166 });
167 self.names.retain(|k, _| {
168 if let GlobalId::Transient(_) = k {
169 false
170 } else {
171 true
172 }
173 });
174 }
175}
176
177impl ExprHumanizer for TestCatalog {
178 fn humanize_id(&self, id: GlobalId) -> Option<String> {
179 self.names.get(&id).map(|s| s.to_string())
180 }
181
182 fn humanize_id_unqualified(&self, id: GlobalId) -> Option<String> {
183 self.names.get(&id).map(|s| s.to_string())
184 }
185
186 fn humanize_id_parts(&self, id: GlobalId) -> Option<Vec<String>> {
187 self.humanize_id_unqualified(id).map(|name| vec![name])
188 }
189
190 fn humanize_sql_scalar_type(&self, ty: &SqlScalarType, postgres_compat: bool) -> String {
191 DummyHumanizer.humanize_sql_scalar_type(ty, postgres_compat)
192 }
193
194 fn column_names_for_id(&self, _id: GlobalId) -> Option<Vec<String>> {
195 None
196 }
197
198 fn humanize_column(&self, _id: GlobalId, _column: usize) -> Option<String> {
199 None
200 }
201
202 fn id_exists(&self, id: GlobalId) -> bool {
203 self.names.contains_key(&id)
204 }
205}
206
207#[derive(Default)]
220pub struct MirScalarExprDeserializeContext;
221
222impl MirScalarExprDeserializeContext {
223 fn build_column(&self, token: Option<TokenTree>) -> Result<MirScalarExpr, String> {
224 if let Some(TokenTree::Literal(literal)) = token {
225 return Ok(MirScalarExpr::column(
226 literal
227 .to_string()
228 .parse::<usize>()
229 .map_err_to_string_with_causes()?,
230 ));
231 }
232 Err(format!(
233 "Invalid column specification {:?}",
234 token.map(|token_tree| format!("`{}`", token_tree))
235 ))
236 }
237
238 fn build_literal_if_able<I>(
239 &self,
240 first_arg: TokenTree,
241 rest_of_stream: &mut I,
242 ) -> Result<Option<MirScalarExpr>, String>
243 where
244 I: Iterator<Item = TokenTree>,
245 {
246 match &first_arg {
247 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("ok") => {
248 let first_arg = if let Some(first_arg) = rest_of_stream.next() {
250 first_arg
251 } else {
252 return Err(format!("expected literal after Ident: `{}`", i));
253 };
254 match self.build_literal_ok_if_able(first_arg, rest_of_stream) {
255 Ok(Some(l)) => Ok(Some(l)),
256 _ => Err(format!("expected literal after Ident: `{}`", i)),
257 }
258 }
259 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("err") => {
260 let error = deserialize_generic(rest_of_stream, "EvalError")?;
261 let typ: Option<SqlScalarType> =
262 deserialize_optional_generic(rest_of_stream, "SqlScalarType")?;
263 Ok(Some(MirScalarExpr::literal(
264 Err(error),
265 ReprScalarType::from(&typ.unwrap_or(SqlScalarType::Bool)),
266 )))
267 }
268 _ => self.build_literal_ok_if_able(first_arg, rest_of_stream),
269 }
270 }
271
272 fn build_literal_ok_if_able<I>(
273 &self,
274 first_arg: TokenTree,
275 rest_of_stream: &mut I,
276 ) -> Result<Option<MirScalarExpr>, String>
277 where
278 I: Iterator<Item = TokenTree>,
279 {
280 match extract_literal_string(&first_arg, rest_of_stream)? {
281 Some(litval) => {
282 let littyp = get_scalar_type_or_default(&litval[..], rest_of_stream)?;
283 Ok(Some(MirScalarExpr::literal_from_single_element_row(
284 test_spec_to_row(std::iter::once((&litval[..], &littyp)))?,
285 ReprScalarType::from(&littyp),
286 )))
287 }
288 None => Ok(None),
289 }
290 }
291}
292
293impl TestDeserializeContext for MirScalarExprDeserializeContext {
294 fn override_syntax<I>(
295 &mut self,
296 first_arg: TokenTree,
297 rest_of_stream: &mut I,
298 type_name: &str,
299 ) -> Result<Option<String>, String>
300 where
301 I: Iterator<Item = TokenTree>,
302 {
303 let result = if type_name == "MirScalarExpr" {
304 match first_arg {
305 TokenTree::Punct(punct) if punct.as_char() == '#' => {
306 Some(self.build_column(rest_of_stream.next())?)
307 }
308 TokenTree::Group(_) => None,
309 symbol => self.build_literal_if_able(symbol, rest_of_stream)?,
310 }
311 } else {
312 None
313 };
314 match result {
315 Some(result) => Ok(Some(
316 serde_json::to_string(&result).map_err_to_string_with_causes()?,
317 )),
318 None => Ok(None),
319 }
320 }
321
322 fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
323 if type_name == "MirScalarExpr" {
324 let map = json.as_object().unwrap();
325 assert_eq!(map.len(), 1);
327 for (variant, data) in map.iter() {
328 match &variant[..] {
329 "Column" => {
330 return Some(format!(
331 "#{}",
332 data.as_array().unwrap()[0].as_u64().unwrap()
333 ));
334 }
335 "Literal" => {
336 let column_type: ReprColumnType =
337 serde_json::from_value(data.as_array().unwrap()[1].clone()).unwrap();
338 let obj = data.as_array().unwrap()[0].as_object().unwrap();
339 if let Some(inner_data) = obj.get("Ok") {
340 let row: Row = serde_json::from_value(inner_data.clone()).unwrap();
341 let result = format!(
342 "({} {})",
343 datum_to_test_spec(row.unpack_first()),
344 serialize::<ReprScalarType, _>(
345 &serde_json::to_value(&column_type.scalar_type).unwrap(),
346 "ReprScalarType",
347 self
348 )
349 );
350 return Some(result);
351 } else if let Some(inner_data) = obj.get("Err") {
352 let result = format!(
353 "(err {} {})",
354 serialize::<EvalError, _>(inner_data, "EvalError", self),
355 serialize::<ReprScalarType, _>(
356 &serde_json::to_value(&column_type.scalar_type).unwrap(),
357 "ReprScalarType",
358 self
359 ),
360 );
361 return Some(result);
362 } else {
363 unreachable!("unexpected JSON data: {:?}", obj);
364 }
365 }
366 _ => {}
367 }
368 }
369 }
370 None
371 }
372}
373
374pub struct MirRelationExprDeserializeContext<'a> {
402 inner_ctx: MirScalarExprDeserializeContext,
403 catalog: &'a TestCatalog,
404 scope: Scope,
408}
409
410impl<'a> MirRelationExprDeserializeContext<'a> {
411 pub fn new(catalog: &'a TestCatalog) -> Self {
412 Self {
413 inner_ctx: MirScalarExprDeserializeContext::default(),
414 catalog,
415 scope: Scope::default(),
416 }
417 }
418
419 pub fn list_scope_references(&self) -> impl Iterator<Item = (&String, &ReprRelationType)> {
420 self.scope.iter()
421 }
422
423 fn build_constant<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
424 where
425 I: Iterator<Item = TokenTree>,
426 {
427 let raw_rows = stream_iter
428 .next()
429 .ok_or_else(|| "Constant is empty".to_string())?;
430 let typ: SqlRelationType = deserialize(stream_iter, "SqlRelationType", self)?;
434
435 let mut rows = Vec::new();
436 match raw_rows {
437 TokenTree::Group(group) => {
438 let mut inner_iter = group.stream().into_iter();
439 while let Some(token) = inner_iter.next() {
440 let row = test_spec_to_row(
441 parse_vec_of_literals(&token)?
442 .iter()
443 .zip_eq(&typ.column_types)
444 .map(|(dat, col_typ)| (&dat[..], &col_typ.scalar_type)),
445 )?;
446 rows.push((row, Diff::ONE));
447 }
448 }
449 invalid => return Err(format!("invalid rows spec for constant `{}`", invalid)),
450 };
451 Ok(MirRelationExpr::Constant {
452 rows: Ok(rows),
453 typ: ReprRelationType::from(&typ),
454 })
455 }
456
457 fn build_constant_err<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
458 where
459 I: Iterator<Item = TokenTree>,
460 {
461 let error: EvalError = deserialize(stream_iter, "EvalError", self)?;
462 let typ: SqlRelationType = deserialize(stream_iter, "SqlRelationType", self)?;
463
464 Ok(MirRelationExpr::Constant {
465 rows: Err(error),
466 typ: ReprRelationType::from(&typ),
467 })
468 }
469
470 fn build_get(&self, token: Option<TokenTree>) -> Result<MirRelationExpr, String> {
471 match token {
472 Some(TokenTree::Ident(ident)) => {
473 let name = ident.to_string();
474 match self.scope.get(&name) {
475 Some((id, typ)) => Ok(MirRelationExpr::Get {
476 id,
477 typ,
478 access_strategy: AccessStrategy::UnknownOrLocal,
479 }),
480 None => match self.catalog.get(&name) {
481 None => Err(format!("no catalog object named {}", name)),
482 Some((id, typ)) => Ok(MirRelationExpr::Get {
483 id: Id::Global(*id),
484 typ: ReprRelationType::from(typ),
485 access_strategy: AccessStrategy::UnknownOrLocal,
486 }),
487 },
488 }
489 }
490 invalid_token => Err(format!(
491 "Invalid get specification {:?}",
492 invalid_token.map(|token_tree| format!("`{}`", token_tree))
493 )),
494 }
495 }
496
497 fn build_let<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
498 where
499 I: Iterator<Item = TokenTree>,
500 {
501 let name = match stream_iter.next() {
502 Some(TokenTree::Ident(ident)) => Ok(ident.to_string()),
503 invalid_token => Err(format!(
504 "Invalid let specification {:?}",
505 invalid_token.map(|token_tree| format!("`{}`", token_tree))
506 )),
507 }?;
508
509 let value: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
510
511 let (id, prev) = self.scope.insert(&name, value.typ());
512
513 let body: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
514
515 if let Some((old_id, old_val)) = prev {
516 self.scope.set(&name, old_id, old_val);
517 } else {
518 self.scope.remove(&name)
519 }
520
521 Ok(MirRelationExpr::Let {
522 id,
523 value: Box::new(value),
524 body: Box::new(body),
525 })
526 }
527
528 fn build_union<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
529 where
530 I: Iterator<Item = TokenTree>,
531 {
532 let mut inputs: Vec<MirRelationExpr> =
533 deserialize(stream_iter, "Vec<MirRelationExpr>", self)?;
534 Ok(MirRelationExpr::Union {
535 base: Box::new(inputs.remove(0)),
536 inputs,
537 })
538 }
539
540 fn build_special_mir_if_able<I>(
541 &mut self,
542 first_arg: TokenTree,
543 rest_of_stream: &mut I,
544 ) -> Result<Option<MirRelationExpr>, String>
545 where
546 I: Iterator<Item = TokenTree>,
547 {
548 if let TokenTree::Ident(ident) = first_arg {
549 return Ok(match &ident.to_string().to_lowercase()[..] {
550 "constant" => Some(self.build_constant(rest_of_stream)?),
551 "constant_err" => Some(self.build_constant_err(rest_of_stream)?),
552 "get" => Some(self.build_get(rest_of_stream.next())?),
553 "let" => Some(self.build_let(rest_of_stream)?),
554 "union" => Some(self.build_union(rest_of_stream)?),
555 _ => None,
556 });
557 }
558 Ok(None)
559 }
560}
561
562impl<'a> TestDeserializeContext for MirRelationExprDeserializeContext<'a> {
563 fn override_syntax<I>(
564 &mut self,
565 first_arg: TokenTree,
566 rest_of_stream: &mut I,
567 type_name: &str,
568 ) -> Result<Option<String>, String>
569 where
570 I: Iterator<Item = TokenTree>,
571 {
572 match self
573 .inner_ctx
574 .override_syntax(first_arg.clone(), rest_of_stream, type_name)?
575 {
576 Some(result) => Ok(Some(result)),
577 None => {
578 if type_name == "MirRelationExpr" {
579 if let Some(result) =
580 self.build_special_mir_if_able(first_arg, rest_of_stream)?
581 {
582 return Ok(Some(
583 serde_json::to_string(&result).map_err_to_string_with_causes()?,
584 ));
585 }
586 } else if type_name == "usize" {
587 if let TokenTree::Punct(punct) = first_arg {
588 if punct.as_char() == '#' {
589 match rest_of_stream.next() {
590 Some(TokenTree::Literal(literal)) => {
591 return Ok(Some(literal.to_string()));
592 }
593 invalid => {
594 return Err(format!(
595 "invalid column value {:?}",
596 invalid.map(|token_tree| format!("`{}`", token_tree))
597 ));
598 }
599 }
600 }
601 }
602 }
603 Ok(None)
604 }
605 }
606 }
607
608 fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
609 match self.inner_ctx.reverse_syntax_override(json, type_name) {
610 Some(result) => Some(result),
611 None => {
612 if type_name == "MirRelationExpr" {
613 let map = json.as_object().unwrap();
614 assert_eq!(
616 map.len(),
617 1,
618 "Multivariant instance {:?} found for MirRelationExpr",
619 map
620 );
621 for (variant, data) in map.iter() {
622 let inner_map = data.as_object().unwrap();
623 match &variant[..] {
624 "Let" => {
625 let id: LocalId =
626 serde_json::from_value(inner_map["id"].clone()).unwrap();
627 return Some(format!(
628 "(let {} {} {})",
629 id,
630 serialize::<MirRelationExpr, _>(
631 &inner_map["value"],
632 "MirRelationExpr",
633 self
634 ),
635 serialize::<MirRelationExpr, _>(
636 &inner_map["body"],
637 "MirRelationExpr",
638 self
639 ),
640 ));
641 }
642 "Get" => {
643 let id: Id =
644 serde_json::from_value(inner_map["id"].clone()).unwrap();
645 return Some(match id {
646 Id::Global(global) => {
647 match self.catalog.get_source_name(&global) {
648 Some(source) => format!("(get {})", source),
651 None => {
653 let repr_typ: ReprRelationType =
654 serde_json::from_value(
655 inner_map["typ"].clone(),
656 )
657 .unwrap();
658 self.scope.insert(&id.to_string(), repr_typ);
659 format!("(get {})", id)
660 }
661 }
662 }
663 _ => {
664 format!("(get {})", id)
665 }
666 });
667 }
668 "Constant" => {
669 if let Some(row_vec) = inner_map["rows"].get("Ok") {
670 let mut rows = Vec::new();
671 for inner_array in row_vec.as_array().unwrap() {
672 let row: Row =
673 serde_json::from_value(inner_array[0].clone()).unwrap();
674 let diff = inner_array[1].as_u64().unwrap();
675 for _ in 0..diff {
676 rows.push(format!(
677 "[{}]",
678 separated(" ", row.iter().map(datum_to_test_spec))
679 ))
680 }
681 }
682 let repr_typ: ReprRelationType =
683 serde_json::from_value(inner_map["typ"].clone()).unwrap();
684 let sql_typ = SqlRelationType::from_repr(&repr_typ);
685 let sql_typ_json = serde_json::to_value(&sql_typ).unwrap();
686 return Some(format!(
687 "(constant [{}] {})",
688 separated(" ", rows),
689 serialize::<SqlRelationType, _>(
690 &sql_typ_json,
691 "SqlRelationType",
692 self
693 )
694 ));
695 } else if let Some(inner_data) = inner_map["rows"].get("Err") {
696 let repr_typ: ReprRelationType =
697 serde_json::from_value(inner_map["typ"].clone()).unwrap();
698 let sql_typ = SqlRelationType::from_repr(&repr_typ);
699 let sql_typ_json = serde_json::to_value(&sql_typ).unwrap();
700 return Some(format!(
701 "(constant_err {} {})",
702 serialize::<EvalError, _>(inner_data, "EvalError", self),
703 serialize::<SqlRelationType, _>(
704 &sql_typ_json,
705 "SqlRelationType",
706 self
707 )
708 ));
709 } else {
710 unreachable!("unexpected JSON data: {:?}", inner_map);
711 }
712 }
713 "Union" => {
714 let mut inputs = inner_map["inputs"].as_array().unwrap().to_owned();
715 inputs.insert(0, inner_map["base"].clone());
716 return Some(format!(
717 "(union {})",
718 serialize::<Vec<MirRelationExpr>, _>(
719 &Value::Array(inputs),
720 "Vec<MirRelationExpr>",
721 self
722 )
723 ));
724 }
725 _ => {}
726 }
727 }
728 }
729 None
730 }
731 }
732 }
733}
734
735#[derive(Debug, Default)]
738struct Scope {
739 objects: BTreeMap<String, (Id, ReprRelationType)>,
740 names: BTreeMap<Id, String>,
741}
742
743impl Scope {
744 fn insert(
745 &mut self,
746 name: &str,
747 typ: ReprRelationType,
748 ) -> (LocalId, Option<(Id, ReprRelationType)>) {
749 let old_val = self.get(name);
750 let id = LocalId::new(u64::cast_from(self.objects.len()));
751 self.set(name, Id::Local(id), typ);
752 (id, old_val)
753 }
754
755 fn set(&mut self, name: &str, id: Id, typ: ReprRelationType) {
756 self.objects.insert(name.to_string(), (id, typ));
757 self.names.insert(id, name.to_string());
758 }
759
760 fn remove(&mut self, name: &str) {
761 self.objects.remove(name);
762 }
763
764 fn get(&self, name: &str) -> Option<(Id, ReprRelationType)> {
765 self.objects.get(name).cloned()
766 }
767
768 fn iter(&self) -> impl Iterator<Item = (&String, &ReprRelationType)> {
769 self.objects.iter().map(|(s, (_, typ))| (s, typ))
770 }
771}