1use std::collections::BTreeMap;
15
16pub use mz_lowertest_derive::MzReflect;
17use mz_ore::result::ResultExt;
18use mz_ore::str::{StrExt, separated};
19use proc_macro2::{Delimiter, TokenStream, TokenTree};
20use serde::de::DeserializeOwned;
21use serde_json::Value;
22
/// A type whose structure (the fields of a struct or the variants of an enum)
/// can be recorded in a [`ReflectedTypeInfo`], so that test specs can be
/// converted to and from JSON for that type.
pub trait MzReflect {
    /// Adds the names and types of this type's fields (for a struct) or
    /// variants (for an enum) to `rti`.
    ///
    /// NOTE(review): implementations are presumably generated by the
    /// `MzReflect` derive macro re-exported above, and are expected to recurse
    /// into referenced types the way the `Vec<T>` impl below does — confirm
    /// against `mz_lowertest_derive`.
    fn add_to_reflected_type_info(rti: &mut ReflectedTypeInfo);
}
36
37impl<T: MzReflect> MzReflect for Vec<T> {
38 fn add_to_reflected_type_info(rti: &mut ReflectedTypeInfo) {
39 T::add_to_reflected_type_info(rti);
40 }
41}
42
/// Reflected structure of a set of types, filled in through
/// [`MzReflect::add_to_reflected_type_info`] and consulted when converting
/// between test-spec token streams and JSON.
#[derive(Debug, Default)]
pub struct ReflectedTypeInfo {
    /// Enum type name -> (variant name -> (field names, field types)).
    ///
    /// A variant with no field types is a unit variant and is converted to a
    /// JSON string. A variant with field types but no field names has unnamed
    /// (tuple-like) fields.
    pub enum_dict:
        BTreeMap<&'static str, BTreeMap<&'static str, (Vec<&'static str>, Vec<&'static str>)>>,
    /// Struct type name -> (field names, field types).
    ///
    /// A struct with no field types is a unit struct and is converted to JSON
    /// `null`. A struct with field types but no field names is a tuple struct.
    pub struct_dict: BTreeMap<&'static str, (Vec<&'static str>, Vec<&'static str>)>,
}
53
54pub fn tokenize(s: &str) -> Result<TokenStream, String> {
60 s.parse::<TokenStream>().map_err_to_string_with_causes()
61}
62
/// Strips one pair of surrounding double quotes from `s` and unescapes every
/// `\"` inside it. A string that is not wrapped in quotes is returned
/// unchanged.
///
/// A lone `"` has no separate opening and closing quote. It is therefore not
/// considered quoted and is returned as-is, instead of panicking on an
/// inverted slice range.
pub fn unquote(s: &str) -> String {
    match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.replace("\\\"", "\""),
        None => s.to_string(),
    }
}
71
72pub fn deserialize_optional_generic<D, I>(
76 stream_iter: &mut I,
77 type_name: &'static str,
78) -> Result<Option<D>, String>
79where
80 D: DeserializeOwned + MzReflect,
81 I: Iterator<Item = TokenTree>,
82{
83 deserialize_optional(
84 stream_iter,
85 type_name,
86 &mut GenericTestDeserializeContext::default(),
87 )
88}
89
90pub fn deserialize_optional<D, I, C>(
98 stream_iter: &mut I,
99 type_name: &'static str,
100 ctx: &mut C,
101) -> Result<Option<D>, String>
102where
103 C: TestDeserializeContext,
104 D: DeserializeOwned + MzReflect,
105 I: Iterator<Item = TokenTree>,
106{
107 let mut rti = ReflectedTypeInfo::default();
108 D::add_to_reflected_type_info(&mut rti);
109 match to_json(stream_iter, type_name, &rti, ctx)? {
110 Some(j) => Ok(Some(serde_json::from_str::<D>(&j).map_err(|e| {
111 format!("String while serializing: {}\nOriginal JSON: {}", e, j)
112 })?)),
113 None => Ok(None),
114 }
115}
116
117pub fn deserialize_generic<D, I>(stream_iter: &mut I, type_name: &'static str) -> Result<D, String>
119where
120 D: DeserializeOwned + MzReflect,
121 I: Iterator<Item = TokenTree>,
122{
123 deserialize(
124 stream_iter,
125 type_name,
126 &mut GenericTestDeserializeContext::default(),
127 )
128}
129
130pub fn deserialize<D, I, C>(
138 stream_iter: &mut I,
139 type_name: &'static str,
140 ctx: &mut C,
141) -> Result<D, String>
142where
143 C: TestDeserializeContext,
144 D: DeserializeOwned + MzReflect,
145 I: Iterator<Item = TokenTree>,
146{
147 deserialize_optional(stream_iter, type_name, ctx)?
148 .ok_or_else(|| format!("Empty spec for type {}", type_name))
149}
150
/// Converts the next object of type `type_name` in `stream_iter` into a JSON
/// string that `serde_json` can deserialize into that type.
///
/// Returns `Ok(None)` when `stream_iter` is exhausted.
///
/// Tokens are interpreted in this order:
/// 1. A unit struct (a struct with no field types) becomes `null` without
///    consuming any token.
/// 2. For `Option<...>` types, the identifier `null` becomes `null`.
/// 3. Types recorded in `rti` (structs and enums) are tried with
///    `parse_as_enum_or_struct`.
/// 4. `ctx.override_syntax` gets a chance to handle the token.
/// 5. Otherwise the raw token is converted as follows:
///    - a `[...]` group becomes a JSON array for `Vec<T>`, `[T]`, or tuple
///      types;
///    - a `,` is skipped;
///    - any other punctuation is prefixed to the next object;
///    - an identifier is quoted when `type_name` is `String`;
///    - a literal is emitted verbatim.
///
/// # Errors
/// Returns an error for a bracketed group whose type is not a sequence or
/// tuple, for a group with any other unhandled delimiter, or when a nested
/// conversion fails.
pub fn to_json<I, C>(
    stream_iter: &mut I,
    type_name: &str,
    rti: &ReflectedTypeInfo,
    ctx: &mut C,
) -> Result<Option<String>, String>
where
    C: TestDeserializeContext,
    I: Iterator<Item = TokenTree>,
{
    // Strip whitespace and any outer `Option<...>` / `Box<...>` wrappers.
    let (type_name, option_found) = normalize_type_name(type_name);

    if let Some((_, f_types)) = rti.struct_dict.get(&type_name[..]) {
        // A unit struct has nothing to parse; it is always represented as `null`.
        if f_types.is_empty() {
            return Ok(Some("null".to_string()));
        }
    }

    if let Some(first_arg) = stream_iter.next() {
        if option_found {
            // `null` denotes `None` for optional types.
            if let TokenTree::Ident(ident) = &first_arg {
                if *ident == "null" {
                    return Ok(Some("null".to_string()));
                }
            }
        }

        // First let the reflected struct/enum parser, and then the context's
        // syntax override, interpret the token. Fall back to the generic
        // per-token conversion below only if both decline.
        if let Some(result) =
            parse_as_enum_or_struct(first_arg.clone(), stream_iter, &type_name, rti, ctx)?
        {
            return Ok(Some(result));
        }
        if let Some(result) = ctx.override_syntax(first_arg.clone(), stream_iter, &type_name)? {
            return Ok(Some(result));
        }
        match first_arg {
            TokenTree::Group(group) => {
                let mut inner_iter = group.stream().into_iter();
                match group.delimiter() {
                    Delimiter::Bracket => {
                        if type_name.starts_with("Vec<") && type_name.ends_with('>') {
                            // `Vec<T>`: every element inside the brackets has type `T`.
                            let vec = parse_as_vec(
                                &mut inner_iter,
                                &type_name[4..(type_name.len() - 1)],
                                rti,
                                ctx,
                            )?;
                            Ok(Some(format!("[{}]", separated(",", vec.iter()))))
                        } else if type_name.starts_with("[") && type_name.ends_with(']') {
                            // `[T]`: every element inside the brackets has type `T`.
                            let vec = parse_as_vec(
                                &mut inner_iter,
                                &type_name[1..(type_name.len() - 1)],
                                rti,
                                ctx,
                            )?;
                            Ok(Some(format!("[{}]", separated(",", vec.iter()))))
                        } else if type_name.starts_with('(') && type_name.ends_with(')') {
                            // Tuple: each element has its own type, in order.
                            let vec = parse_as_tuple(
                                &mut inner_iter,
                                &type_name[1..(type_name.len() - 1)],
                                rti,
                                ctx,
                            )?;
                            Ok(Some(format!("[{}]", separated(",", vec.iter()))))
                        } else {
                            Err(format!(
                                "Object specified with brackets {:?} has unsupported type `{}`",
                                inner_iter.collect::<Vec<_>>(),
                                type_name
                            ))
                        }
                    }
                    delim => Err(format!(
                        "Object spec {:?} (type {}) has unsupported delimiter {:?}",
                        inner_iter.collect::<Vec<_>>(),
                        type_name,
                        delim
                    )),
                }
            }
            TokenTree::Punct(punct) => {
                match punct.as_char() {
                    // Commas between objects are skipped.
                    ',' => to_json(stream_iter, &type_name, rti, ctx),
                    // Other punctuation is glued onto the following object
                    // (e.g. a leading `-`). If nothing follows, it stands alone.
                    other => match to_json(stream_iter, &type_name, rti, ctx)? {
                        Some(result) => Ok(Some(format!("{}{}", other, result))),
                        None => Ok(Some(other.to_string())),
                    },
                }
            }
            TokenTree::Ident(ident) => {
                // A bare identifier given for a `String` is the string's
                // contents, except `null`, which stays a JSON null.
                if type_name == "String" && ident != "null" {
                    Ok(Some(ident.to_string().quoted().to_string()))
                } else {
                    Ok(Some(ident.to_string()))
                }
            }
            TokenTree::Literal(literal) => Ok(Some(literal.to_string())),
        }
    } else {
        Ok(None)
    }
}
299
/// Hooks that let a test framework extend or override the default test-spec
/// syntax when converting between token streams and JSON.
pub trait TestDeserializeContext {
    /// Tries to convert `first_arg` into a JSON string representing an object
    /// of type `type_name`. If needed, further tokens can be taken from
    /// `rest_of_stream`.
    ///
    /// Return `Ok(None)` to decline; the default conversion in `to_json` then
    /// continues with the same `first_arg` and stream.
    ///
    /// NOTE(review): because the default conversion continues on the same
    /// stream, an implementation presumably should not consume tokens when it
    /// declines — confirm.
    fn override_syntax<I>(
        &mut self,
        first_arg: TokenTree,
        rest_of_stream: &mut I,
        type_name: &str,
    ) -> Result<Option<String>, String>
    where
        I: Iterator<Item = TokenTree>;

    /// The inverse of `override_syntax`: tries to render `json`, an object of
    /// type `type_name`, in the overridden test-spec syntax.
    ///
    /// Return `None` to fall back to the default rendering in `from_json`.
    fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String>;
}
333
334#[derive(Default)]
338struct GenericTestDeserializeContext;
339
340impl TestDeserializeContext for GenericTestDeserializeContext {
341 fn override_syntax<I>(
342 &mut self,
343 _first_arg: TokenTree,
344 _rest_of_stream: &mut I,
345 _type_name: &str,
346 ) -> Result<Option<String>, String>
347 where
348 I: Iterator<Item = TokenTree>,
349 {
350 Ok(None)
351 }
352
353 fn reverse_syntax_override(&mut self, _: &Value, _: &str) -> Option<String> {
354 None
355 }
356}
357
358fn parse_as_vec<I, C>(
362 stream_iter: &mut I,
363 type_name: &str,
364 rti: &ReflectedTypeInfo,
365 ctx: &mut C,
366) -> Result<Vec<String>, String>
367where
368 C: TestDeserializeContext,
369 I: Iterator<Item = TokenTree>,
370{
371 let mut result = Vec::new();
372 while let Some(element) = to_json(stream_iter, type_name, rti, ctx)? {
373 result.push(element);
374 }
375 Ok(result)
376}
377
378fn parse_as_tuple<I, C>(
382 stream_iter: &mut I,
383 type_name: &str,
384 rti: &ReflectedTypeInfo,
385 ctx: &mut C,
386) -> Result<Vec<String>, String>
387where
388 C: TestDeserializeContext,
389 I: Iterator<Item = TokenTree>,
390{
391 let mut prev_elem_end = 0;
392 let mut result = Vec::new();
393 while let Some((next_elem_begin, next_elem_end)) =
394 find_next_type_in_tuple(type_name, prev_elem_end)
395 {
396 match to_json(
397 stream_iter,
398 &type_name[next_elem_begin..next_elem_end],
399 rti,
400 ctx,
401 )? {
402 Some(elem) => result.push(elem),
403 None => break,
406 }
407 prev_elem_end = next_elem_end;
408 }
409 Ok(result)
410}
411
412fn parse_as_enum_or_struct<I, C>(
418 first_arg: TokenTree,
419 rest_of_stream: &mut I,
420 type_name: &str,
421 rti: &ReflectedTypeInfo,
422 ctx: &mut C,
423) -> Result<Option<String>, String>
424where
425 C: TestDeserializeContext,
426 I: Iterator<Item = TokenTree>,
427{
428 if rti.enum_dict.contains_key(type_name) || rti.struct_dict.contains_key(type_name) {
429 match first_arg {
434 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
435 let mut inner_iter = group.stream().into_iter();
436 match inner_iter.next() {
437 Some(first_arg) => parse_as_enum_or_struct_inner(
439 first_arg,
440 &mut inner_iter,
441 type_name,
442 rti,
443 ctx,
444 ),
445 None => Ok(None),
446 }
447 }
448 TokenTree::Punct(punct) => {
449 let mut consecutive_punct = Vec::new();
454 while let Some(token) = rest_of_stream.next() {
455 consecutive_punct.push(token);
456 match &consecutive_punct[consecutive_punct.len() - 1] {
457 TokenTree::Punct(_) => {}
458 _ => {
459 break;
460 }
461 }
462 }
463 parse_as_enum_or_struct_inner(
464 TokenTree::Punct(punct),
465 &mut consecutive_punct.into_iter(),
466 type_name,
467 rti,
468 ctx,
469 )
470 }
471 other => {
472 parse_as_enum_or_struct_inner(other, &mut std::iter::empty(), type_name, rti, ctx)
476 }
477 }
478 } else {
479 Ok(None)
480 }
481}
482
483fn parse_as_enum_or_struct_inner<I, C>(
484 first_arg: TokenTree,
485 rest_of_stream: &mut I,
486 type_name: &str,
487 rti: &ReflectedTypeInfo,
488 ctx: &mut C,
489) -> Result<Option<String>, String>
490where
491 C: TestDeserializeContext,
492 I: Iterator<Item = TokenTree>,
493{
494 if let Some(result) = ctx.override_syntax(first_arg.clone(), rest_of_stream, type_name)? {
495 Ok(Some(result))
496 } else if let Some((f_names, f_types)) = rti.struct_dict.get(type_name).map(|r| r.clone()) {
497 Ok(Some(to_json_fields(
498 type_name,
499 &mut (&mut std::iter::once(first_arg)).chain(rest_of_stream),
500 f_names,
501 f_types,
502 rti,
503 ctx,
504 )?))
505 } else if let TokenTree::Ident(ident) = first_arg {
506 Ok(Some(to_json_generic_enum(
507 ident.to_string(),
508 rest_of_stream,
509 type_name,
510 rti,
511 ctx,
512 )?))
513 } else {
514 Ok(None)
515 }
516}
517
518fn to_json_generic_enum<I, C>(
520 variant_snake_case: String,
521 rest_of_stream: &mut I,
522 type_name: &str,
523 rti: &ReflectedTypeInfo,
524 ctx: &mut C,
525) -> Result<String, String>
526where
527 C: TestDeserializeContext,
528 I: Iterator<Item = TokenTree>,
529{
530 let variant_camel_case = variant_snake_case
532 .split('_')
533 .map(|s| {
534 let mut chars = s.chars();
535 let result = chars
536 .next()
537 .map(|c| c.to_uppercase().chain(chars).collect::<String>())
538 .unwrap_or_else(String::new);
539 result
540 })
541 .collect::<Vec<_>>()
542 .concat();
543 let (f_names, f_types) = rti
544 .enum_dict
545 .get(type_name)
546 .unwrap()
547 .get(&variant_camel_case[..])
548 .map(|v| v.clone())
549 .ok_or_else(|| {
550 format!(
551 "{}::{} is not a supported enum.",
552 type_name, variant_camel_case
553 )
554 })?;
555 if f_types.is_empty() {
558 Ok(format!("\"{}\"", variant_camel_case))
560 } else {
561 let fields = to_json_fields(
562 &variant_camel_case,
563 rest_of_stream,
564 f_names,
565 f_types,
566 rti,
567 ctx,
568 )?;
569 Ok(format!("{{\"{}\":{}}}", variant_camel_case, fields))
570 }
571}
572
573fn to_json_fields<I, C>(
579 debug_name: &str,
580 stream_iter: &mut I,
581 f_names: Vec<&'static str>,
582 f_types: Vec<&'static str>,
583 rti: &ReflectedTypeInfo,
584 ctx: &mut C,
585) -> Result<String, String>
586where
587 C: TestDeserializeContext,
588 I: Iterator<Item = TokenTree>,
589{
590 let mut f_values = Vec::new();
591 for t in f_types.iter() {
592 match to_json(stream_iter, t, rti, ctx)? {
593 Some(value) => f_values.push(value),
594 None => {
595 break;
596 }
597 }
598 }
599 if !f_names.is_empty() {
600 Ok(format!(
603 "{{{}}}",
604 separated(
605 ",",
606 f_names
607 .iter()
608 .zip(f_values.into_iter())
609 .map(|(n, v)| format!("\"{}\":{}", n, v))
610 )
611 ))
612 } else {
613 if f_types.len() == 1 {
617 Ok(f_values
618 .pop()
619 .ok_or_else(|| format!("Cannot use default value for {}", debug_name))?)
620 } else {
621 Ok(format!("[{}]", separated(",", f_values.into_iter())))
622 }
623 }
624}
625
626pub fn serialize_generic<M>(json: &Value, type_name: &str) -> String
630where
631 M: MzReflect,
632{
633 let mut rti = ReflectedTypeInfo::default();
634 M::add_to_reflected_type_info(&mut rti);
635 from_json(
636 json,
637 type_name,
638 &rti,
639 &mut GenericTestDeserializeContext::default(),
640 )
641}
642
643pub fn serialize<M, C>(json: &Value, type_name: &str, ctx: &mut C) -> String
644where
645 C: TestDeserializeContext,
646 M: MzReflect,
647{
648 let mut rti = ReflectedTypeInfo::default();
649 M::add_to_reflected_type_info(&mut rti);
650 from_json(json, type_name, &rti, ctx)
651}
652
653pub fn from_json<C>(json: &Value, type_name: &str, rti: &ReflectedTypeInfo, ctx: &mut C) -> String
660where
661 C: TestDeserializeContext,
662{
663 let (type_name, option_found) = normalize_type_name(type_name);
664 if option_found {
668 if let Value::Null = json {
669 return "null".to_string();
670 }
671 }
672 if let Some(result) = ctx.reverse_syntax_override(json, &type_name) {
673 return result;
674 }
675 if let Some((names, types)) = rti.struct_dict.get(&type_name[..]) {
676 if types.is_empty() {
677 "".to_string()
678 } else {
679 format!("({})", from_json_fields(json, names, types, rti, ctx))
680 }
681 } else if let Some(enum_dict) = rti.enum_dict.get(&type_name[..]) {
682 match json {
683 Value::String(s) => unquote(s),
685 Value::Object(map) => {
688 assert_eq!(
690 map.len(),
691 1,
692 "Multivariant instance {:?} found for enum {}",
693 map,
694 type_name
695 );
696 for (variant, data) in map.iter() {
697 if let Some((names, types)) = enum_dict.get(&variant[..]) {
698 return format!(
699 "({} {})",
700 variant,
701 from_json_fields(data, names, types, rti, ctx)
702 );
703 }
704 }
705 unreachable!()
706 }
707 _ => unreachable!("Invalid json {:?} for enum type {}", json, type_name),
708 }
709 } else {
710 match json {
711 Value::Array(members) => {
712 let result = if type_name.starts_with("Vec<") && type_name.ends_with('>') {
713 members
715 .iter()
716 .map(|v| from_json(v, &type_name[4..(type_name.len() - 1)], rti, ctx))
717 .collect::<Vec<_>>()
718 } else {
719 let mut result = Vec::new();
721 let type_name = &type_name[1..(type_name.len() - 1)];
722 let mut prev_elem_end = 0;
723 let mut members_iter = members.into_iter();
724 while let Some((next_elem_begin, next_elem_end)) =
725 find_next_type_in_tuple(type_name, prev_elem_end)
726 {
727 match members_iter.next() {
728 Some(elem) => result.push(from_json(
729 elem,
730 &type_name[next_elem_begin..next_elem_end],
731 rti,
732 ctx,
733 )),
734 None => break,
736 }
737 prev_elem_end = next_elem_end;
738 }
739 result
740 };
741 format!("[{}]", separated(" ", result))
743 }
744 Value::Object(map) => {
745 unreachable!("Invalid map {:?} found for type {}", map, type_name)
746 }
747 other => other.to_string(),
748 }
749 }
750}
751
752fn from_json_fields<C>(
753 v: &Value,
754 f_names: &[&'static str],
755 f_types: &[&'static str],
756 rti: &ReflectedTypeInfo,
757 ctx: &mut C,
758) -> String
759where
760 C: TestDeserializeContext,
761{
762 match v {
763 Value::Object(map) if !f_names.is_empty() => {
769 let mut fields = Vec::with_capacity(f_types.len());
770 for (name, typ) in f_names.iter().zip(f_types.iter()) {
771 fields.push(from_json(&map[*name], typ, rti, ctx))
772 }
773 separated(" ", fields).to_string()
774 }
775 Value::Array(inner) if f_types.len() > 1 => {
778 let mut fields = Vec::with_capacity(f_types.len());
779 for (v, typ) in inner.iter().zip(f_types.iter()) {
780 fields.push(from_json(v, typ, rti, ctx))
781 }
782 separated(" ", fields).to_string()
783 }
784 other => from_json(other, f_types.first().unwrap(), rti, ctx),
786 }
787}
788
/// Normalizes a Rust type name for lookups in [`ReflectedTypeInfo`].
///
/// All whitespace is removed, and any number of outer `Option<...>` and
/// `Box<...>` wrappers are peeled off.
///
/// Returns the normalized name, and whether an `Option<...>` wrapper was
/// removed along the way (callers use this to accept or emit `null`).
fn normalize_type_name(type_name: &str) -> (String, bool) {
    // Type names may span lines or use arbitrary spacing (tabs, `\r\n`), so
    // strip every kind of whitespace, not just spaces and newlines.
    let normalized: String = type_name.chars().filter(|c| !c.is_whitespace()).collect();
    let mut remaining = &normalized[..];
    let mut option_found = false;
    loop {
        if let Some(inner) = remaining
            .strip_prefix("Option<")
            .and_then(|rest| rest.strip_suffix('>'))
        {
            option_found = true;
            remaining = inner;
        } else if let Some(inner) = remaining
            .strip_prefix("Box<")
            .and_then(|rest| rest.strip_suffix('>'))
        {
            remaining = inner;
        } else {
            break;
        }
    }
    (remaining.to_string(), option_found)
}
815
/// Finds the next element type in `type_name`.
///
/// `type_name` is the comma-separated contents of a tuple type, with the outer
/// parentheses already removed (e.g. `A,Vec<B>,(C,D)`).
///
/// `prev_elem_end` is the end offset returned by the previous call, or `0` to
/// start from the beginning. Because `0` doubles as the start marker, an empty
/// leading element (as in `,A`) keeps being returned.
///
/// Returns the byte range `(begin, end)` of the next element, or `None` once
/// every element has been returned. Commas nested inside `(...)` or `<...>` do
/// not split elements.
///
/// The offsets are byte offsets, so they can slice `type_name` directly. All
/// delimiters are ASCII, so the offsets always fall on char boundaries, even
/// when the name contains non-ASCII characters.
fn find_next_type_in_tuple(type_name: &str, prev_elem_end: usize) -> Option<(usize, usize)> {
    // Skip the comma that ended the previous element.
    let current_elem_begin = if prev_elem_end > 0 {
        prev_elem_end + 1
    } else {
        prev_elem_end
    };
    let bytes = type_name.as_bytes();
    if current_elem_begin >= bytes.len() {
        return None;
    }
    let mut paren_level = 0;
    let mut bracket_level = 0;
    let mut end = bytes.len();
    for (offset, &byte) in bytes[current_elem_begin..].iter().enumerate() {
        match byte {
            b',' if paren_level == 0 && bracket_level == 0 => {
                end = current_elem_begin + offset;
                break;
            }
            b'(' => paren_level += 1,
            b')' => paren_level -= 1,
            b'<' => bracket_level += 1,
            b'>' => bracket_level -= 1,
            _ => {}
        }
    }
    Some((current_elem_begin, end))
}
852
853