1use std::collections::BTreeMap;
11
12use mz_expr::{AccessStrategy, EvalError, Id, LocalId, MirRelationExpr, MirScalarExpr};
13use mz_lowertest::*;
14use mz_ore::cast::CastFrom;
15use mz_ore::result::ResultExt;
16use mz_ore::str::separated;
17use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
18use mz_repr::{Diff, GlobalId, Row, SqlColumnType, SqlRelationType, SqlScalarType};
19use mz_repr_test_util::*;
20use proc_macro2::TokenTree;
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23
24pub fn build_scalar(s: &str) -> Result<MirScalarExpr, String> {
28 deserialize(
29 &mut tokenize(s)?.into_iter(),
30 "MirScalarExpr",
31 &mut MirScalarExprDeserializeContext::default(),
32 )
33}
34
35pub fn build_rel(s: &str, catalog: &TestCatalog) -> Result<MirRelationExpr, String> {
39 deserialize(
40 &mut tokenize(s)?.into_iter(),
41 "MirRelationExpr",
42 &mut MirRelationExprDeserializeContext::new(catalog),
43 )
44}
45
46pub fn json_to_spec(rel_json: &str, catalog: &TestCatalog) -> (String, Vec<String>) {
54 let mut ctx = MirRelationExprDeserializeContext::new(catalog);
55 let spec = serialize::<MirRelationExpr, _>(
56 &serde_json::from_str(rel_json).unwrap(),
57 "MirRelationExpr",
58 &mut ctx,
59 );
60 let mut source_defs = ctx
61 .list_scope_references()
62 .map(|(name, typ)| {
63 format!(
64 "(defsource {} {})",
65 name,
66 serialize_generic::<SqlRelationType>(
67 &serde_json::to_value(typ).unwrap(),
68 "SqlRelationType",
69 )
70 )
71 })
72 .collect::<Vec<_>>();
73 source_defs.sort();
74 (spec, source_defs)
75}
76
/// An in-memory catalog of named relations.
///
/// Test specs resolve `get <name>` against it, and it humanizes ids when printing.
#[derive(Debug, Default)]
pub struct TestCatalog {
    /// Object name -> (id, relation type).
    objects: BTreeMap<String, (GlobalId, SqlRelationType)>,
    /// Id -> object name; the reverse of `objects`.
    names: BTreeMap<GlobalId, String>,
}
86
/// Commands understood by `TestCatalog::handle_test_command`.
#[derive(Debug, Serialize, Deserialize, MzReflect)]
enum TestCatalogCommand {
    /// `(defsource name typ)`: registers a non-transient object `name` with type `typ`.
    Defsource { name: String, typ: SqlRelationType },
}
95
96impl<'a> TestCatalog {
97 pub fn insert(
106 &mut self,
107 name: &str,
108 typ: SqlRelationType,
109 transient: bool,
110 ) -> Result<GlobalId, String> {
111 if self.objects.contains_key(name) {
112 return Err(format!("Object {} already exists in catalog", name));
113 }
114 let id = if transient {
115 GlobalId::Transient(u64::cast_from(self.objects.len()))
116 } else {
117 GlobalId::User(u64::cast_from(self.objects.len()))
118 };
119 self.objects.insert(name.to_string(), (id, typ));
120 self.names.insert(id, name.to_string());
121 Ok(id)
122 }
123
124 fn get(&'a self, name: &str) -> Option<&'a (GlobalId, SqlRelationType)> {
125 self.objects.get(name)
126 }
127
128 pub fn get_source_name(&'a self, id: &GlobalId) -> Option<&'a String> {
130 self.names.get(id)
131 }
132
133 pub fn handle_test_command(&mut self, spec: &str) -> Result<(), String> {
139 let mut stream_iter = tokenize(spec)?.into_iter();
140 while let Some(command) = deserialize_optional_generic::<TestCatalogCommand, _>(
141 &mut stream_iter,
142 "TestCatalogCommand",
143 )? {
144 match command {
145 TestCatalogCommand::Defsource { name, typ } => {
146 self.insert(&name, typ, false)?;
147 }
148 }
149 }
150 Ok(())
151 }
152
153 pub fn remove_transient_objects(&mut self) {
155 self.objects.retain(|_, (id, _)| {
156 if let GlobalId::Transient(_) = id {
157 false
158 } else {
159 true
160 }
161 });
162 self.names.retain(|k, _| {
163 if let GlobalId::Transient(_) = k {
164 false
165 } else {
166 true
167 }
168 });
169 }
170}
171
172impl ExprHumanizer for TestCatalog {
173 fn humanize_id(&self, id: GlobalId) -> Option<String> {
174 self.names.get(&id).map(|s| s.to_string())
175 }
176
177 fn humanize_id_unqualified(&self, id: GlobalId) -> Option<String> {
178 self.names.get(&id).map(|s| s.to_string())
179 }
180
181 fn humanize_id_parts(&self, id: GlobalId) -> Option<Vec<String>> {
182 self.humanize_id_unqualified(id).map(|name| vec![name])
183 }
184
185 fn humanize_scalar_type(&self, ty: &SqlScalarType, postgres_compat: bool) -> String {
186 DummyHumanizer.humanize_scalar_type(ty, postgres_compat)
187 }
188
189 fn column_names_for_id(&self, _id: GlobalId) -> Option<Vec<String>> {
190 None
191 }
192
193 fn humanize_column(&self, _id: GlobalId, _column: usize) -> Option<String> {
194 None
195 }
196
197 fn id_exists(&self, id: GlobalId) -> bool {
198 self.names.contains_key(&id)
199 }
200}
201
/// Stateless context that overrides the default test syntax for [`MirScalarExpr`].
///
/// It recognizes three forms:
/// - `#n` parses as column `n`.
/// - A literal, optionally preceded by `ok`, parses as an `Ok` literal.
/// - `err` followed by an [`EvalError`] and an optional [`SqlScalarType`] (default `Bool`)
///   parses as an error literal.
#[derive(Default)]
pub struct MirScalarExprDeserializeContext;
216
217impl MirScalarExprDeserializeContext {
218 fn build_column(&self, token: Option<TokenTree>) -> Result<MirScalarExpr, String> {
219 if let Some(TokenTree::Literal(literal)) = token {
220 return Ok(MirScalarExpr::column(
221 literal
222 .to_string()
223 .parse::<usize>()
224 .map_err_to_string_with_causes()?,
225 ));
226 }
227 Err(format!(
228 "Invalid column specification {:?}",
229 token.map(|token_tree| format!("`{}`", token_tree))
230 ))
231 }
232
233 fn build_literal_if_able<I>(
234 &self,
235 first_arg: TokenTree,
236 rest_of_stream: &mut I,
237 ) -> Result<Option<MirScalarExpr>, String>
238 where
239 I: Iterator<Item = TokenTree>,
240 {
241 match &first_arg {
242 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("ok") => {
243 let first_arg = if let Some(first_arg) = rest_of_stream.next() {
245 first_arg
246 } else {
247 return Err(format!("expected literal after Ident: `{}`", i));
248 };
249 match self.build_literal_ok_if_able(first_arg, rest_of_stream) {
250 Ok(Some(l)) => Ok(Some(l)),
251 _ => Err(format!("expected literal after Ident: `{}`", i)),
252 }
253 }
254 TokenTree::Ident(i) if i.to_string().eq_ignore_ascii_case("err") => {
255 let error = deserialize_generic(rest_of_stream, "EvalError")?;
256 let typ: Option<SqlScalarType> =
257 deserialize_optional_generic(rest_of_stream, "SqlScalarType")?;
258 Ok(Some(MirScalarExpr::literal(
259 Err(error),
260 typ.unwrap_or(SqlScalarType::Bool),
261 )))
262 }
263 _ => self.build_literal_ok_if_able(first_arg, rest_of_stream),
264 }
265 }
266
267 fn build_literal_ok_if_able<I>(
268 &self,
269 first_arg: TokenTree,
270 rest_of_stream: &mut I,
271 ) -> Result<Option<MirScalarExpr>, String>
272 where
273 I: Iterator<Item = TokenTree>,
274 {
275 match extract_literal_string(&first_arg, rest_of_stream)? {
276 Some(litval) => {
277 let littyp = get_scalar_type_or_default(&litval[..], rest_of_stream)?;
278 Ok(Some(MirScalarExpr::Literal(
279 Ok(test_spec_to_row(std::iter::once((&litval[..], &littyp)))?),
280 littyp.nullable(matches!(&litval[..], "null")),
281 )))
282 }
283 None => Ok(None),
284 }
285 }
286}
287
288impl TestDeserializeContext for MirScalarExprDeserializeContext {
289 fn override_syntax<I>(
290 &mut self,
291 first_arg: TokenTree,
292 rest_of_stream: &mut I,
293 type_name: &str,
294 ) -> Result<Option<String>, String>
295 where
296 I: Iterator<Item = TokenTree>,
297 {
298 let result = if type_name == "MirScalarExpr" {
299 match first_arg {
300 TokenTree::Punct(punct) if punct.as_char() == '#' => {
301 Some(self.build_column(rest_of_stream.next())?)
302 }
303 TokenTree::Group(_) => None,
304 symbol => self.build_literal_if_able(symbol, rest_of_stream)?,
305 }
306 } else {
307 None
308 };
309 match result {
310 Some(result) => Ok(Some(
311 serde_json::to_string(&result).map_err_to_string_with_causes()?,
312 )),
313 None => Ok(None),
314 }
315 }
316
317 fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
318 if type_name == "MirScalarExpr" {
319 let map = json.as_object().unwrap();
320 assert_eq!(map.len(), 1);
322 for (variant, data) in map.iter() {
323 match &variant[..] {
324 "Column" => {
325 return Some(format!(
326 "#{}",
327 data.as_array().unwrap()[0].as_u64().unwrap()
328 ));
329 }
330 "Literal" => {
331 let column_type: SqlColumnType =
332 serde_json::from_value(data.as_array().unwrap()[1].clone()).unwrap();
333 let obj = data.as_array().unwrap()[0].as_object().unwrap();
334 if let Some(inner_data) = obj.get("Ok") {
335 let row: Row = serde_json::from_value(inner_data.clone()).unwrap();
336 let result = format!(
337 "({} {})",
338 datum_to_test_spec(row.unpack_first()),
339 serialize::<SqlScalarType, _>(
340 &serde_json::to_value(&column_type.scalar_type).unwrap(),
341 "SqlScalarType",
342 self
343 )
344 );
345 return Some(result);
346 } else if let Some(inner_data) = obj.get("Err") {
347 let result = format!(
348 "(err {} {})",
349 serialize::<EvalError, _>(inner_data, "EvalError", self),
350 serialize::<SqlScalarType, _>(
351 &serde_json::to_value(&column_type.scalar_type).unwrap(),
352 "SqlScalarType",
353 self
354 ),
355 );
356 return Some(result);
357 } else {
358 unreachable!("unexpected JSON data: {:?}", obj);
359 }
360 }
361 _ => {}
362 }
363 }
364 }
365 None
366 }
367}
368
/// Context for parsing and printing the [`MirRelationExpr`] test syntax.
///
/// Scalar syntax is delegated to `inner_ctx`. While parsing, `get <name>` is resolved first
/// against `scope` (names bound by enclosing `let`s) and then against `catalog`.
pub struct MirRelationExprDeserializeContext<'a> {
    /// Handles the custom `MirScalarExpr` syntax.
    inner_ctx: MirScalarExprDeserializeContext,
    /// Catalog used to resolve global `get` names and to print global ids by name.
    catalog: &'a TestCatalog,
    /// Local name bindings.
    ///
    /// While printing, this also collects global ids that `catalog` has no name for.
    scope: Scope,
}
404
405impl<'a> MirRelationExprDeserializeContext<'a> {
406 pub fn new(catalog: &'a TestCatalog) -> Self {
407 Self {
408 inner_ctx: MirScalarExprDeserializeContext::default(),
409 catalog,
410 scope: Scope::default(),
411 }
412 }
413
414 pub fn list_scope_references(&self) -> impl Iterator<Item = (&String, &SqlRelationType)> {
415 self.scope.iter()
416 }
417
418 fn build_constant<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
419 where
420 I: Iterator<Item = TokenTree>,
421 {
422 let raw_rows = stream_iter
423 .next()
424 .ok_or_else(|| "Constant is empty".to_string())?;
425 let typ: SqlRelationType = deserialize(stream_iter, "SqlRelationType", self)?;
429
430 let mut rows = Vec::new();
431 match raw_rows {
432 TokenTree::Group(group) => {
433 let mut inner_iter = group.stream().into_iter();
434 while let Some(token) = inner_iter.next() {
435 let row = test_spec_to_row(
436 parse_vec_of_literals(&token)?
437 .iter()
438 .zip(&typ.column_types)
439 .map(|(dat, col_typ)| (&dat[..], &col_typ.scalar_type)),
440 )?;
441 rows.push((row, Diff::ONE));
442 }
443 }
444 invalid => return Err(format!("invalid rows spec for constant `{}`", invalid)),
445 };
446 Ok(MirRelationExpr::Constant {
447 rows: Ok(rows),
448 typ,
449 })
450 }
451
452 fn build_constant_err<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
453 where
454 I: Iterator<Item = TokenTree>,
455 {
456 let error: EvalError = deserialize(stream_iter, "EvalError", self)?;
457 let typ: SqlRelationType = deserialize(stream_iter, "SqlRelationType", self)?;
458
459 Ok(MirRelationExpr::Constant {
460 rows: Err(error),
461 typ,
462 })
463 }
464
465 fn build_get(&self, token: Option<TokenTree>) -> Result<MirRelationExpr, String> {
466 match token {
467 Some(TokenTree::Ident(ident)) => {
468 let name = ident.to_string();
469 match self.scope.get(&name) {
470 Some((id, typ)) => Ok(MirRelationExpr::Get {
471 id,
472 typ,
473 access_strategy: AccessStrategy::UnknownOrLocal,
474 }),
475 None => match self.catalog.get(&name) {
476 None => Err(format!("no catalog object named {}", name)),
477 Some((id, typ)) => Ok(MirRelationExpr::Get {
478 id: Id::Global(*id),
479 typ: typ.clone(),
480 access_strategy: AccessStrategy::UnknownOrLocal,
481 }),
482 },
483 }
484 }
485 invalid_token => Err(format!(
486 "Invalid get specification {:?}",
487 invalid_token.map(|token_tree| format!("`{}`", token_tree))
488 )),
489 }
490 }
491
492 fn build_let<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
493 where
494 I: Iterator<Item = TokenTree>,
495 {
496 let name = match stream_iter.next() {
497 Some(TokenTree::Ident(ident)) => Ok(ident.to_string()),
498 invalid_token => Err(format!(
499 "Invalid let specification {:?}",
500 invalid_token.map(|token_tree| format!("`{}`", token_tree))
501 )),
502 }?;
503
504 let value: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
505
506 let (id, prev) = self.scope.insert(&name, value.typ());
507
508 let body: MirRelationExpr = deserialize(stream_iter, "MirRelationExpr", self)?;
509
510 if let Some((old_id, old_val)) = prev {
511 self.scope.set(&name, old_id, old_val);
512 } else {
513 self.scope.remove(&name)
514 }
515
516 Ok(MirRelationExpr::Let {
517 id,
518 value: Box::new(value),
519 body: Box::new(body),
520 })
521 }
522
523 fn build_union<I>(&mut self, stream_iter: &mut I) -> Result<MirRelationExpr, String>
524 where
525 I: Iterator<Item = TokenTree>,
526 {
527 let mut inputs: Vec<MirRelationExpr> =
528 deserialize(stream_iter, "Vec<MirRelationExpr>", self)?;
529 Ok(MirRelationExpr::Union {
530 base: Box::new(inputs.remove(0)),
531 inputs,
532 })
533 }
534
535 fn build_special_mir_if_able<I>(
536 &mut self,
537 first_arg: TokenTree,
538 rest_of_stream: &mut I,
539 ) -> Result<Option<MirRelationExpr>, String>
540 where
541 I: Iterator<Item = TokenTree>,
542 {
543 if let TokenTree::Ident(ident) = first_arg {
544 return Ok(match &ident.to_string().to_lowercase()[..] {
545 "constant" => Some(self.build_constant(rest_of_stream)?),
546 "constant_err" => Some(self.build_constant_err(rest_of_stream)?),
547 "get" => Some(self.build_get(rest_of_stream.next())?),
548 "let" => Some(self.build_let(rest_of_stream)?),
549 "union" => Some(self.build_union(rest_of_stream)?),
550 _ => None,
551 });
552 }
553 Ok(None)
554 }
555}
556
impl<'a> TestDeserializeContext for MirRelationExprDeserializeContext<'a> {
    /// Parses the custom syntax for `MirRelationExpr` and `usize`, returning it as a JSON
    /// string.
    ///
    /// Scalar syntax is tried first via `inner_ctx`. For `MirRelationExpr`, the special
    /// keywords (`constant`, `constant_err`, `get`, `let`, `union`) are handled here. For
    /// `usize`, `#n` is accepted as a column index. Otherwise returns `Ok(None)` so the default
    /// syntax applies.
    fn override_syntax<I>(
        &mut self,
        first_arg: TokenTree,
        rest_of_stream: &mut I,
        type_name: &str,
    ) -> Result<Option<String>, String>
    where
        I: Iterator<Item = TokenTree>,
    {
        match self
            .inner_ctx
            .override_syntax(first_arg.clone(), rest_of_stream, type_name)?
        {
            Some(result) => Ok(Some(result)),
            None => {
                if type_name == "MirRelationExpr" {
                    if let Some(result) =
                        self.build_special_mir_if_able(first_arg, rest_of_stream)?
                    {
                        return Ok(Some(
                            serde_json::to_string(&result).map_err_to_string_with_causes()?,
                        ));
                    }
                } else if type_name == "usize" {
                    // Column indices may be written as `#n`; emit just the number.
                    if let TokenTree::Punct(punct) = first_arg {
                        if punct.as_char() == '#' {
                            match rest_of_stream.next() {
                                Some(TokenTree::Literal(literal)) => {
                                    return Ok(Some(literal.to_string()));
                                }
                                invalid => {
                                    return Err(format!(
                                        "invalid column value {:?}",
                                        invalid.map(|token_tree| format!("`{}`", token_tree))
                                    ));
                                }
                            }
                        }
                    }
                }
                Ok(None)
            }
        }
    }

    /// Prints a JSON-serialized `MirRelationExpr` in the custom syntax.
    ///
    /// Handles the `Let`, `Get`, `Constant`, and `Union` variants. Other variants (and other
    /// types not handled by `inner_ctx`) return `None` so the default printer is used.
    ///
    /// Side effect: a `Get` of a global id that the catalog has no name for is added to
    /// `self.scope`, so callers can later list it via `list_scope_references`.
    fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String> {
        match self.inner_ctx.reverse_syntax_override(json, type_name) {
            Some(result) => Some(result),
            None => {
                if type_name == "MirRelationExpr" {
                    let map = json.as_object().unwrap();
                    assert_eq!(
                        map.len(),
                        1,
                        "Multivariant instance {:?} found for MirRelationExpr",
                        map
                    );
                    // Single-entry object of the form { variant: fields }.
                    for (variant, data) in map.iter() {
                        let inner_map = data.as_object().unwrap();
                        match &variant[..] {
                            "Let" => {
                                // Printed as `(let <local id> value body)`; the local id's
                                // Display form serves as the bound name.
                                let id: LocalId =
                                    serde_json::from_value(inner_map["id"].clone()).unwrap();
                                return Some(format!(
                                    "(let {} {} {})",
                                    id,
                                    serialize::<MirRelationExpr, _>(
                                        &inner_map["value"],
                                        "MirRelationExpr",
                                        self
                                    ),
                                    serialize::<MirRelationExpr, _>(
                                        &inner_map["body"],
                                        "MirRelationExpr",
                                        self
                                    ),
                                ));
                            }
                            "Get" => {
                                let id: Id =
                                    serde_json::from_value(inner_map["id"].clone()).unwrap();
                                return Some(match id {
                                    Id::Global(global) => {
                                        match self.catalog.get_source_name(&global) {
                                            Some(source) => format!("(get {})", source),
                                            None => {
                                                // Unknown to the catalog: record the id and
                                                // its type so a `defsource` can be emitted
                                                // for it.
                                                let typ: SqlRelationType = serde_json::from_value(
                                                    inner_map["typ"].clone(),
                                                )
                                                .unwrap();
                                                self.scope.insert(&id.to_string(), typ);
                                                format!("(get {})", id)
                                            }
                                        }
                                    }
                                    _ => {
                                        format!("(get {})", id)
                                    }
                                });
                            }
                            "Constant" => {
                                if let Some(row_vec) = inner_map["rows"].get("Ok") {
                                    let mut rows = Vec::new();
                                    for inner_array in row_vec.as_array().unwrap() {
                                        // Each entry is [row, diff]. The test syntax has no diff,
                                        // so a row is repeated `diff` times.
                                        // NOTE(review): `as_u64().unwrap()` panics on a negative
                                        // diff (retraction) — confirm such constants never
                                        // reach this path.
                                        let row: Row =
                                            serde_json::from_value(inner_array[0].clone()).unwrap();
                                        let diff = inner_array[1].as_u64().unwrap();
                                        for _ in 0..diff {
                                            rows.push(format!(
                                                "[{}]",
                                                separated(" ", row.iter().map(datum_to_test_spec))
                                            ))
                                        }
                                    }
                                    return Some(format!(
                                        "(constant [{}] {})",
                                        separated(" ", rows),
                                        serialize::<SqlRelationType, _>(
                                            &inner_map["typ"],
                                            "SqlRelationType",
                                            self
                                        )
                                    ));
                                } else if let Some(inner_data) = inner_map["rows"].get("Err") {
                                    return Some(format!(
                                        "(constant_err {} {})",
                                        serialize::<EvalError, _>(inner_data, "EvalError", self),
                                        serialize::<SqlRelationType, _>(
                                            &inner_map["typ"],
                                            "SqlRelationType",
                                            self
                                        )
                                    ));
                                } else {
                                    unreachable!("unexpected JSON data: {:?}", inner_map);
                                }
                            }
                            "Union" => {
                                // Re-join `base` and `inputs` into the single list that
                                // `union` takes.
                                let mut inputs = inner_map["inputs"].as_array().unwrap().to_owned();
                                inputs.insert(0, inner_map["base"].clone());
                                return Some(format!(
                                    "(union {})",
                                    serialize::<Vec<MirRelationExpr>, _>(
                                        &Value::Array(inputs),
                                        "Vec<MirRelationExpr>",
                                        self
                                    )
                                ));
                            }
                            _ => {}
                        }
                    }
                }
                None
            }
        }
    }
}
720
/// Local name bindings used while parsing or printing a relation expression.
#[derive(Debug, Default)]
struct Scope {
    /// Name -> (id, relation type) of the binding currently visible under that name.
    objects: BTreeMap<String, (Id, SqlRelationType)>,
    /// Id -> name, written by `Scope::set`.
    /// NOTE(review): not read anywhere in the visible source.
    names: BTreeMap<Id, String>,
}
728
729impl Scope {
730 fn insert(
731 &mut self,
732 name: &str,
733 typ: SqlRelationType,
734 ) -> (LocalId, Option<(Id, SqlRelationType)>) {
735 let old_val = self.get(name);
736 let id = LocalId::new(u64::cast_from(self.objects.len()));
737 self.set(name, Id::Local(id), typ);
738 (id, old_val)
739 }
740
741 fn set(&mut self, name: &str, id: Id, typ: SqlRelationType) {
742 self.objects.insert(name.to_string(), (id, typ));
743 self.names.insert(id, name.to_string());
744 }
745
746 fn remove(&mut self, name: &str) {
747 self.objects.remove(name);
748 }
749
750 fn get(&self, name: &str) -> Option<(Id, SqlRelationType)> {
751 self.objects.get(name).cloned()
752 }
753
754 fn iter(&self) -> impl Iterator<Item = (&String, &SqlRelationType)> {
755 self.objects.iter().map(|(s, (_, typ))| (s, typ))
756 }
757}