1use std::collections::BTreeMap;
15
16pub use mz_lowertest_derive::MzReflect;
17use mz_ore::result::ResultExt;
18use mz_ore::str::{StrExt, separated};
19use mz_ore::treat_as_equal::TreatAsEqual;
20use proc_macro2::{Delimiter, TokenStream, TokenTree};
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23
/// A type whose shape can be recorded in a [`ReflectedTypeInfo`], so that
/// test specs for it can be translated to and from serde JSON.
pub trait MzReflect {
    /// Registers this type's reflected information in `rti`.
    ///
    /// NOTE(review): the impls visible here either forward to an element type
    /// (`Vec<T>`) or add nothing (`TreatAsEqual<T>`); the derive macro
    /// presumably also registers the types of a struct's/enum's fields —
    /// confirm against `mz_lowertest_derive`.
    fn add_to_reflected_type_info(rti: &mut ReflectedTypeInfo);
}
37
38impl<T: MzReflect> MzReflect for Vec<T> {
39 fn add_to_reflected_type_info(rti: &mut ReflectedTypeInfo) {
40 T::add_to_reflected_type_info(rti);
41 }
42}
43
/// `TreatAsEqual<T>` contributes nothing to the reflected type information,
/// and (unlike `Vec<T>`) does not require `T: MzReflect`.
impl<T> MzReflect for TreatAsEqual<T> {
    fn add_to_reflected_type_info(_rti: &mut ReflectedTypeInfo) {}
}
47
/// Name-indexed descriptions of reflected structs and enums, consulted when
/// converting test specs to JSON (`to_json`) and JSON back to specs
/// (`from_json`).
#[derive(Debug, Default)]
pub struct ReflectedTypeInfo {
    /// Enum type name -> variant name (CamelCase) -> (field names, field types).
    ///
    /// A variant whose type list is empty is encoded as a bare JSON string.
    /// An empty name list with a non-empty type list is encoded positionally
    /// (presumably tuple-style variants — confirm with the derive macro).
    pub enum_dict:
        BTreeMap<&'static str, BTreeMap<&'static str, (Vec<&'static str>, Vec<&'static str>)>>,
    /// Struct type name -> (field names, field types). A struct with an empty
    /// type list is encoded as JSON `null` by `to_json`.
    pub struct_dict: BTreeMap<&'static str, (Vec<&'static str>, Vec<&'static str>)>,
}
58
59pub fn tokenize(s: &str) -> Result<TokenStream, String> {
65 s.parse::<TokenStream>().map_err_to_string_with_causes()
66}
67
/// Strips one pair of surrounding double quotes from `s`, if present, and
/// turns every `\"` in the interior back into `"`.
///
/// Text that is not wrapped in a pair of quotes is returned unchanged. This
/// includes a lone `"`, which previously caused an out-of-range slice panic
/// because it both starts and ends with a quote.
pub fn unquote(s: &str) -> String {
    match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.replace("\\\"", "\""),
        None => s.to_string(),
    }
}
76
77pub fn deserialize_optional_generic<D, I>(
81 stream_iter: &mut I,
82 type_name: &'static str,
83) -> Result<Option<D>, String>
84where
85 D: DeserializeOwned + MzReflect,
86 I: Iterator<Item = TokenTree>,
87{
88 deserialize_optional(
89 stream_iter,
90 type_name,
91 &mut GenericTestDeserializeContext::default(),
92 )
93}
94
95pub fn deserialize_optional<D, I, C>(
103 stream_iter: &mut I,
104 type_name: &'static str,
105 ctx: &mut C,
106) -> Result<Option<D>, String>
107where
108 C: TestDeserializeContext,
109 D: DeserializeOwned + MzReflect,
110 I: Iterator<Item = TokenTree>,
111{
112 let mut rti = ReflectedTypeInfo::default();
113 D::add_to_reflected_type_info(&mut rti);
114 match to_json(stream_iter, type_name, &rti, ctx)? {
115 Some(j) => Ok(Some(serde_json::from_str::<D>(&j).map_err(|e| {
116 format!("String while serializing: {}\nOriginal JSON: {}", e, j)
117 })?)),
118 None => Ok(None),
119 }
120}
121
122pub fn deserialize_generic<D, I>(stream_iter: &mut I, type_name: &'static str) -> Result<D, String>
124where
125 D: DeserializeOwned + MzReflect,
126 I: Iterator<Item = TokenTree>,
127{
128 deserialize(
129 stream_iter,
130 type_name,
131 &mut GenericTestDeserializeContext::default(),
132 )
133}
134
135pub fn deserialize<D, I, C>(
143 stream_iter: &mut I,
144 type_name: &'static str,
145 ctx: &mut C,
146) -> Result<D, String>
147where
148 C: TestDeserializeContext,
149 D: DeserializeOwned + MzReflect,
150 I: Iterator<Item = TokenTree>,
151{
152 deserialize_optional(stream_iter, type_name, ctx)?
153 .ok_or_else(|| format!("Empty spec for type {}", type_name))
154}
155
/// Converts the spec at the front of `stream_iter` into a JSON string that
/// serde can deserialize as `type_name`.
///
/// `type_name` is first normalized by removing spaces/newlines and any
/// `Option<...>` / `Box<...>` wrappers. Conversion is then tried in order:
/// 1. A struct in `rti.struct_dict` with no fields becomes `null`; no token
///    is read from `stream_iter`.
/// 2. If an `Option` wrapper was removed, the identifier `null` becomes `null`.
/// 3. `parse_as_enum_or_struct`, for types registered in `rti`.
/// 4. `ctx.override_syntax`, for caller-defined syntax.
/// 5. Generic handling of the first token:
///    - a `[...]` group for `Vec<T>`, `[...]`, or tuple `(...)` types becomes
///      a JSON array;
///    - `,` is skipped and the next value is converted instead;
///    - any other punctuation is prefixed to the following value, or emitted
///      alone if nothing follows;
///    - an identifier is passed through `quoted()` when `type_name` is
///      `String` (except `null`), and otherwise emitted as-is;
///    - a literal is emitted verbatim.
///
/// Returns `Ok(None)` when `stream_iter` has no token left to convert (after
/// the fieldless-struct check in step 1).
///
/// # Errors
/// Fails on a `[...]` group whose type is not a `Vec`, `[...]`, or tuple
/// type, or on a group with any other delimiter. Also propagates errors from
/// nested conversions and from `ctx`.
pub fn to_json<I, C>(
    stream_iter: &mut I,
    type_name: &str,
    rti: &ReflectedTypeInfo,
    ctx: &mut C,
) -> Result<Option<String>, String>
where
    C: TestDeserializeContext,
    I: Iterator<Item = TokenTree>,
{
    let (type_name, option_found) = normalize_type_name(type_name);

    // A fieldless struct needs no spec: emit `null` without consuming tokens.
    // NOTE(review): because nothing is consumed, a caller that loops until
    // `None` (e.g. `parse_as_vec` with a fieldless-struct element type) would
    // never terminate — confirm that case cannot arise.
    if let Some((_, f_types)) = rti.struct_dict.get(&type_name[..]) {
        if f_types.is_empty() {
            return Ok(Some("null".to_string()));
        }
    }

    if let Some(first_arg) = stream_iter.next() {
        if option_found {
            // For `Option<...>` types the identifier `null` spells `None`.
            if let TokenTree::Ident(ident) = &first_arg {
                if *ident == "null" {
                    return Ok(Some("null".to_string()));
                }
            }
        }

        // Reflected structs/enums take precedence over caller overrides.
        if let Some(result) =
            parse_as_enum_or_struct(first_arg.clone(), stream_iter, &type_name, rti, ctx)?
        {
            return Ok(Some(result));
        }
        if let Some(result) = ctx.override_syntax(first_arg.clone(), stream_iter, &type_name)? {
            return Ok(Some(result));
        }
        match first_arg {
            TokenTree::Group(group) => {
                let mut inner_iter = group.stream().into_iter();
                match group.delimiter() {
                    Delimiter::Bracket => {
                        if type_name.starts_with("Vec<") && type_name.ends_with('>') {
                            // `Vec<T>`: every element is parsed as `T`.
                            let vec = parse_as_vec(
                                &mut inner_iter,
                                &type_name[4..(type_name.len() - 1)],
                                rti,
                                ctx,
                            )?;
                            Ok(Some(format!("[{}]", separated(",", vec.iter()))))
                        } else if type_name.starts_with("[") && type_name.ends_with(']') {
                            // `[...]` type: every element is parsed with the
                            // text between the brackets as its type.
                            let vec = parse_as_vec(
                                &mut inner_iter,
                                &type_name[1..(type_name.len() - 1)],
                                rti,
                                ctx,
                            )?;
                            Ok(Some(format!("[{}]", separated(",", vec.iter()))))
                        } else if type_name.starts_with('(') && type_name.ends_with(')') {
                            // Tuple: the i-th element is parsed as the i-th
                            // component type; the result is a JSON array.
                            let vec = parse_as_tuple(
                                &mut inner_iter,
                                &type_name[1..(type_name.len() - 1)],
                                rti,
                                ctx,
                            )?;
                            Ok(Some(format!("[{}]", separated(",", vec.iter()))))
                        } else {
                            Err(format!(
                                "Object specified with brackets {:?} has unsupported type `{}`",
                                inner_iter.collect::<Vec<_>>(),
                                type_name
                            ))
                        }
                    }
                    delim => Err(format!(
                        "Object spec {:?} (type {}) has unsupported delimiter {:?}",
                        inner_iter.collect::<Vec<_>>(),
                        type_name,
                        delim
                    )),
                }
            }
            TokenTree::Punct(punct) => {
                match punct.as_char() {
                    // A comma only separates values; convert the next one.
                    ',' => to_json(stream_iter, &type_name, rti, ctx),
                    // Other punctuation (e.g. a `-` sign) is glued onto the
                    // front of the value that follows it.
                    other => match to_json(stream_iter, &type_name, rti, ctx)? {
                        Some(result) => Ok(Some(format!("{}{}", other, result))),
                        None => Ok(Some(other.to_string())),
                    },
                }
            }
            TokenTree::Ident(ident) => {
                // Bare identifiers are accepted as `String` values and passed
                // through `quoted()`; `null` is left unquoted.
                if type_name == "String" && ident != "null" {
                    Ok(Some(ident.to_string().quoted().to_string()))
                } else {
                    Ok(Some(ident.to_string()))
                }
            }
            TokenTree::Literal(literal) => Ok(Some(literal.to_string())),
        }
    } else {
        Ok(None)
    }
}
304
/// Caller-supplied hooks for customizing how specs are converted to JSON and
/// how JSON is printed back as specs, per type.
pub trait TestDeserializeContext {
    /// Optionally converts the spec starting with `first_arg` into JSON for
    /// `type_name`, pulling further tokens from `rest_of_stream` as needed.
    ///
    /// Return `Ok(Some(json))` to supply the JSON for this value, or
    /// `Ok(None)` to fall back to the default conversion. Within this module,
    /// `type_name` arrives already normalized: no whitespace, and no
    /// `Option<...>` / `Box<...>` wrappers.
    ///
    /// NOTE(review): tokens pulled from `rest_of_stream` before returning
    /// `Ok(None)` are not put back; implementations presumably should avoid
    /// consuming in that case — confirm.
    fn override_syntax<I>(
        &mut self,
        first_arg: TokenTree,
        rest_of_stream: &mut I,
        type_name: &str,
    ) -> Result<Option<String>, String>
    where
        I: Iterator<Item = TokenTree>;

    /// Optionally prints `json`, a serialized value of type `type_name`, in
    /// spec syntax. Return `None` to fall back to the default printing.
    fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String>;
}
338
339#[derive(Default)]
343struct GenericTestDeserializeContext;
344
345impl TestDeserializeContext for GenericTestDeserializeContext {
346 fn override_syntax<I>(
347 &mut self,
348 _first_arg: TokenTree,
349 _rest_of_stream: &mut I,
350 _type_name: &str,
351 ) -> Result<Option<String>, String>
352 where
353 I: Iterator<Item = TokenTree>,
354 {
355 Ok(None)
356 }
357
358 fn reverse_syntax_override(&mut self, _: &Value, _: &str) -> Option<String> {
359 None
360 }
361}
362
363fn parse_as_vec<I, C>(
367 stream_iter: &mut I,
368 type_name: &str,
369 rti: &ReflectedTypeInfo,
370 ctx: &mut C,
371) -> Result<Vec<String>, String>
372where
373 C: TestDeserializeContext,
374 I: Iterator<Item = TokenTree>,
375{
376 let mut result = Vec::new();
377 while let Some(element) = to_json(stream_iter, type_name, rti, ctx)? {
378 result.push(element);
379 }
380 Ok(result)
381}
382
383fn parse_as_tuple<I, C>(
387 stream_iter: &mut I,
388 type_name: &str,
389 rti: &ReflectedTypeInfo,
390 ctx: &mut C,
391) -> Result<Vec<String>, String>
392where
393 C: TestDeserializeContext,
394 I: Iterator<Item = TokenTree>,
395{
396 let mut prev_elem_end = 0;
397 let mut result = Vec::new();
398 while let Some((next_elem_begin, next_elem_end)) =
399 find_next_type_in_tuple(type_name, prev_elem_end)
400 {
401 match to_json(
402 stream_iter,
403 &type_name[next_elem_begin..next_elem_end],
404 rti,
405 ctx,
406 )? {
407 Some(elem) => result.push(elem),
408 None => break,
411 }
412 prev_elem_end = next_elem_end;
413 }
414 Ok(result)
415}
416
417fn parse_as_enum_or_struct<I, C>(
423 first_arg: TokenTree,
424 rest_of_stream: &mut I,
425 type_name: &str,
426 rti: &ReflectedTypeInfo,
427 ctx: &mut C,
428) -> Result<Option<String>, String>
429where
430 C: TestDeserializeContext,
431 I: Iterator<Item = TokenTree>,
432{
433 if rti.enum_dict.contains_key(type_name) || rti.struct_dict.contains_key(type_name) {
434 match first_arg {
439 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
440 let mut inner_iter = group.stream().into_iter();
441 match inner_iter.next() {
442 Some(first_arg) => parse_as_enum_or_struct_inner(
444 first_arg,
445 &mut inner_iter,
446 type_name,
447 rti,
448 ctx,
449 ),
450 None => Ok(None),
451 }
452 }
453 TokenTree::Punct(punct) => {
454 let mut consecutive_punct = Vec::new();
459 while let Some(token) = rest_of_stream.next() {
460 consecutive_punct.push(token);
461 match &consecutive_punct[consecutive_punct.len() - 1] {
462 TokenTree::Punct(_) => {}
463 _ => {
464 break;
465 }
466 }
467 }
468 parse_as_enum_or_struct_inner(
469 TokenTree::Punct(punct),
470 &mut consecutive_punct.into_iter(),
471 type_name,
472 rti,
473 ctx,
474 )
475 }
476 other => {
477 parse_as_enum_or_struct_inner(other, &mut std::iter::empty(), type_name, rti, ctx)
481 }
482 }
483 } else {
484 Ok(None)
485 }
486}
487
488fn parse_as_enum_or_struct_inner<I, C>(
489 first_arg: TokenTree,
490 rest_of_stream: &mut I,
491 type_name: &str,
492 rti: &ReflectedTypeInfo,
493 ctx: &mut C,
494) -> Result<Option<String>, String>
495where
496 C: TestDeserializeContext,
497 I: Iterator<Item = TokenTree>,
498{
499 if let Some(result) = ctx.override_syntax(first_arg.clone(), rest_of_stream, type_name)? {
500 Ok(Some(result))
501 } else if let Some((f_names, f_types)) = rti.struct_dict.get(type_name).map(|r| r.clone()) {
502 Ok(Some(to_json_fields(
503 type_name,
504 &mut (&mut std::iter::once(first_arg)).chain(rest_of_stream),
505 f_names,
506 f_types,
507 rti,
508 ctx,
509 )?))
510 } else if let TokenTree::Ident(ident) = first_arg {
511 Ok(Some(to_json_generic_enum(
512 ident.to_string(),
513 rest_of_stream,
514 type_name,
515 rti,
516 ctx,
517 )?))
518 } else {
519 Ok(None)
520 }
521}
522
523fn to_json_generic_enum<I, C>(
525 variant_snake_case: String,
526 rest_of_stream: &mut I,
527 type_name: &str,
528 rti: &ReflectedTypeInfo,
529 ctx: &mut C,
530) -> Result<String, String>
531where
532 C: TestDeserializeContext,
533 I: Iterator<Item = TokenTree>,
534{
535 let variant_camel_case = variant_snake_case
537 .split('_')
538 .map(|s| {
539 let mut chars = s.chars();
540 let result = chars
541 .next()
542 .map(|c| c.to_uppercase().chain(chars).collect::<String>())
543 .unwrap_or_else(String::new);
544 result
545 })
546 .collect::<Vec<_>>()
547 .concat();
548 let (f_names, f_types) = rti
549 .enum_dict
550 .get(type_name)
551 .unwrap()
552 .get(&variant_camel_case[..])
553 .map(|v| v.clone())
554 .ok_or_else(|| {
555 format!(
556 "{}::{} is not a supported enum.",
557 type_name, variant_camel_case
558 )
559 })?;
560 if f_types.is_empty() {
563 Ok(format!("\"{}\"", variant_camel_case))
565 } else {
566 let fields = to_json_fields(
567 &variant_camel_case,
568 rest_of_stream,
569 f_names,
570 f_types,
571 rti,
572 ctx,
573 )?;
574 Ok(format!("{{\"{}\":{}}}", variant_camel_case, fields))
575 }
576}
577
578fn to_json_fields<I, C>(
584 debug_name: &str,
585 stream_iter: &mut I,
586 f_names: Vec<&'static str>,
587 f_types: Vec<&'static str>,
588 rti: &ReflectedTypeInfo,
589 ctx: &mut C,
590) -> Result<String, String>
591where
592 C: TestDeserializeContext,
593 I: Iterator<Item = TokenTree>,
594{
595 let mut f_values = Vec::new();
596 for t in f_types.iter() {
597 match to_json(stream_iter, t, rti, ctx)? {
598 Some(value) => f_values.push(value),
599 None => {
600 break;
601 }
602 }
603 }
604 if !f_names.is_empty() {
605 Ok(format!(
608 "{{{}}}",
609 separated(
610 ",",
611 f_names
612 .iter()
613 .zip(f_values.into_iter())
614 .map(|(n, v)| format!("\"{}\":{}", n, v))
615 )
616 ))
617 } else {
618 if f_types.len() == 1 {
622 Ok(f_values
623 .pop()
624 .ok_or_else(|| format!("Cannot use default value for {}", debug_name))?)
625 } else {
626 Ok(format!("[{}]", separated(",", f_values.into_iter())))
627 }
628 }
629}
630
631pub fn serialize_generic<M>(json: &Value, type_name: &str) -> String
635where
636 M: MzReflect,
637{
638 let mut rti = ReflectedTypeInfo::default();
639 M::add_to_reflected_type_info(&mut rti);
640 from_json(
641 json,
642 type_name,
643 &rti,
644 &mut GenericTestDeserializeContext::default(),
645 )
646}
647
648pub fn serialize<M, C>(json: &Value, type_name: &str, ctx: &mut C) -> String
649where
650 C: TestDeserializeContext,
651 M: MzReflect,
652{
653 let mut rti = ReflectedTypeInfo::default();
654 M::add_to_reflected_type_info(&mut rti);
655 from_json(json, type_name, &rti, ctx)
656}
657
658pub fn from_json<C>(json: &Value, type_name: &str, rti: &ReflectedTypeInfo, ctx: &mut C) -> String
665where
666 C: TestDeserializeContext,
667{
668 let (type_name, option_found) = normalize_type_name(type_name);
669 if option_found {
673 if let Value::Null = json {
674 return "null".to_string();
675 }
676 }
677 if let Some(result) = ctx.reverse_syntax_override(json, &type_name) {
678 return result;
679 }
680 if let Some((names, types)) = rti.struct_dict.get(&type_name[..]) {
681 if types.is_empty() {
682 "".to_string()
683 } else {
684 format!("({})", from_json_fields(json, names, types, rti, ctx))
685 }
686 } else if let Some(enum_dict) = rti.enum_dict.get(&type_name[..]) {
687 match json {
688 Value::String(s) => unquote(s),
690 Value::Object(map) => {
693 assert_eq!(
695 map.len(),
696 1,
697 "Multivariant instance {:?} found for enum {}",
698 map,
699 type_name
700 );
701 for (variant, data) in map.iter() {
702 if let Some((names, types)) = enum_dict.get(&variant[..]) {
703 return format!(
704 "({} {})",
705 variant,
706 from_json_fields(data, names, types, rti, ctx)
707 );
708 }
709 }
710 unreachable!()
711 }
712 _ => unreachable!("Invalid json {:?} for enum type {}", json, type_name),
713 }
714 } else {
715 match json {
716 Value::Array(members) => {
717 let result = if type_name.starts_with("Vec<") && type_name.ends_with('>') {
718 members
720 .iter()
721 .map(|v| from_json(v, &type_name[4..(type_name.len() - 1)], rti, ctx))
722 .collect::<Vec<_>>()
723 } else {
724 let mut result = Vec::new();
726 let type_name = &type_name[1..(type_name.len() - 1)];
727 let mut prev_elem_end = 0;
728 let mut members_iter = members.into_iter();
729 while let Some((next_elem_begin, next_elem_end)) =
730 find_next_type_in_tuple(type_name, prev_elem_end)
731 {
732 match members_iter.next() {
733 Some(elem) => result.push(from_json(
734 elem,
735 &type_name[next_elem_begin..next_elem_end],
736 rti,
737 ctx,
738 )),
739 None => break,
741 }
742 prev_elem_end = next_elem_end;
743 }
744 result
745 };
746 format!("[{}]", separated(" ", result))
748 }
749 Value::Object(map) => {
750 unreachable!("Invalid map {:?} found for type {}", map, type_name)
751 }
752 other => other.to_string(),
753 }
754 }
755}
756
757fn from_json_fields<C>(
758 v: &Value,
759 f_names: &[&'static str],
760 f_types: &[&'static str],
761 rti: &ReflectedTypeInfo,
762 ctx: &mut C,
763) -> String
764where
765 C: TestDeserializeContext,
766{
767 match v {
768 Value::Object(map) if !f_names.is_empty() => {
774 let mut fields = Vec::with_capacity(f_types.len());
775 for (name, typ) in f_names.iter().zip(f_types.iter()) {
776 fields.push(from_json(&map[*name], typ, rti, ctx))
777 }
778 separated(" ", fields).to_string()
779 }
780 Value::Array(inner) if f_types.len() > 1 => {
783 let mut fields = Vec::with_capacity(f_types.len());
784 for (v, typ) in inner.iter().zip(f_types.iter()) {
785 fields.push(from_json(v, typ, rti, ctx))
786 }
787 separated(" ", fields).to_string()
788 }
789 other => from_json(other, f_types.first().unwrap(), rti, ctx),
791 }
792}
793
/// Canonicalizes a type name for lookups.
///
/// Removes all spaces and newlines, then repeatedly peels off outer
/// `Option<...>` and `Box<...>` wrappers. Returns the remaining name together
/// with whether at least one `Option` wrapper was removed.
fn normalize_type_name(type_name: &str) -> (String, bool) {
    let compact: String = type_name
        .chars()
        .filter(|c| *c != ' ' && *c != '\n')
        .collect();
    let mut remaining = compact.as_str();
    let mut saw_option = false;
    loop {
        if let Some(inner) = remaining
            .strip_prefix("Option<")
            .and_then(|rest| rest.strip_suffix('>'))
        {
            saw_option = true;
            remaining = inner;
        } else if let Some(inner) = remaining
            .strip_prefix("Box<")
            .and_then(|rest| rest.strip_suffix('>'))
        {
            remaining = inner;
        } else {
            break;
        }
    }
    (remaining.to_owned(), saw_option)
}
820
/// Locates the next component type in `type_name`, the comma-separated
/// interior of a tuple type (e.g. `i32,Vec<(A,B)>`).
///
/// `prev_elem_end` is the end offset returned by the previous call, or `0` to
/// start at the beginning. For a non-zero value, the separating comma at that
/// offset is skipped. Commas nested inside `(...)` or `<...>` do not split
/// components.
///
/// Returns the `(begin, end)` **byte** range of the component, suitable for
/// `&type_name[begin..end]`, or `None` when no components remain.
///
/// The scan is byte-based. All delimiters are ASCII, and ASCII bytes never
/// occur inside multi-byte UTF-8 sequences, so the returned offsets always
/// fall on char boundaries. The previous version walked chars while comparing
/// against byte lengths, which produced wrong offsets (and slicing panics)
/// for non-ASCII names.
///
/// NOTE(review): `0` doubles as "start" and as a real end offset, so an empty
/// first component (a name starting with `,`) makes the next call restart at
/// offset 0.
fn find_next_type_in_tuple(type_name: &str, prev_elem_end: usize) -> Option<(usize, usize)> {
    let current_elem_begin = if prev_elem_end > 0 {
        // Skip the comma that ended the previous component.
        prev_elem_end + 1
    } else {
        prev_elem_end
    };
    let bytes = type_name.as_bytes();
    if current_elem_begin >= bytes.len() {
        return None;
    }

    let mut paren_level: i32 = 0;
    let mut bracket_level: i32 = 0;
    let mut end = current_elem_begin;
    while end < bytes.len() {
        match bytes[end] {
            b',' if paren_level == 0 && bracket_level == 0 => break,
            b'(' => paren_level += 1,
            b')' => paren_level -= 1,
            b'<' => bracket_level += 1,
            b'>' => bracket_level -= 1,
            _ => {}
        }
        end += 1;
    }

    Some((current_elem_begin, end))
}
857
858