1use std::collections::BTreeMap;
15
16use itertools::Itertools;
17pub use mz_lowertest_derive::MzReflect;
18use mz_ore::result::ResultExt;
19use mz_ore::str::{StrExt, separated};
20use mz_ore::treat_as_equal::TreatAsEqual;
21use proc_macro2::{Delimiter, TokenStream, TokenTree};
22use serde::de::DeserializeOwned;
23use serde_json::Value;
24
/// Types whose structure can be recorded in a [`ReflectedTypeInfo`].
///
/// The recorded information drives both directions of the test-syntax
/// conversion: [`to_json`] (test syntax -> JSON) and [`from_json`]
/// (JSON -> test syntax). A derive macro with the same name is re-exported
/// from `mz_lowertest_derive`.
pub trait MzReflect {
    /// Adds whatever is known about the structure of `Self` to `rti`.
    ///
    /// NOTE(review): derived impls presumably also recurse into the types of
    /// `Self`'s fields; confirm against `mz_lowertest_derive`.
    fn add_to_reflected_type_info(rti: &mut ReflectedTypeInfo);
}
38
39impl<T: MzReflect> MzReflect for Vec<T> {
40 fn add_to_reflected_type_info(rti: &mut ReflectedTypeInfo) {
41 T::add_to_reflected_type_info(rti);
42 }
43}
44
45impl<T> MzReflect for TreatAsEqual<T> {
46 fn add_to_reflected_type_info(_rti: &mut ReflectedTypeInfo) {}
47}
48
/// Structural information about enums and structs, keyed by type name, that
/// the test-syntax converters consult.
#[derive(Debug, Default)]
pub struct ReflectedTypeInfo {
    /// Enum type name -> variant name (CamelCase) -> (field names, field types).
    ///
    /// A variant with no field types is a unit variant and is emitted as a
    /// bare JSON string. An empty field-name list means the fields are
    /// positional.
    pub enum_dict:
        BTreeMap<&'static str, BTreeMap<&'static str, (Vec<&'static str>, Vec<&'static str>)>>,
    /// Struct type name -> (field names, field types).
    ///
    /// A struct with no field types is emitted as JSON `null` without
    /// consuming any tokens. An empty field-name list means the fields are
    /// positional.
    pub struct_dict: BTreeMap<&'static str, (Vec<&'static str>, Vec<&'static str>)>,
}
59
60pub fn tokenize(s: &str) -> Result<TokenStream, String> {
66 s.parse::<TokenStream>().map_err_to_string_with_causes()
67}
68
/// Strips one pair of surrounding double quotes from `s` and unescapes any
/// `\"` sequences inside them.
///
/// Strings that are not wrapped in a pair of double quotes are returned
/// unchanged. This includes a lone `"`, which previously caused a slicing
/// panic. Escapes other than `\"` are left as-is.
pub fn unquote(s: &str) -> String {
    match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.replace("\\\"", "\""),
        None => s.to_string(),
    }
}
77
78pub fn deserialize_optional_generic<D, I>(
82 stream_iter: &mut I,
83 type_name: &'static str,
84) -> Result<Option<D>, String>
85where
86 D: DeserializeOwned + MzReflect,
87 I: Iterator<Item = TokenTree>,
88{
89 deserialize_optional(
90 stream_iter,
91 type_name,
92 &mut GenericTestDeserializeContext::default(),
93 )
94}
95
96pub fn deserialize_optional<D, I, C>(
104 stream_iter: &mut I,
105 type_name: &'static str,
106 ctx: &mut C,
107) -> Result<Option<D>, String>
108where
109 C: TestDeserializeContext,
110 D: DeserializeOwned + MzReflect,
111 I: Iterator<Item = TokenTree>,
112{
113 let mut rti = ReflectedTypeInfo::default();
114 D::add_to_reflected_type_info(&mut rti);
115 match to_json(stream_iter, type_name, &rti, ctx)? {
116 Some(j) => Ok(Some(serde_json::from_str::<D>(&j).map_err(|e| {
117 format!("String while serializing: {}\nOriginal JSON: {}", e, j)
118 })?)),
119 None => Ok(None),
120 }
121}
122
123pub fn deserialize_generic<D, I>(stream_iter: &mut I, type_name: &'static str) -> Result<D, String>
125where
126 D: DeserializeOwned + MzReflect,
127 I: Iterator<Item = TokenTree>,
128{
129 deserialize(
130 stream_iter,
131 type_name,
132 &mut GenericTestDeserializeContext::default(),
133 )
134}
135
136pub fn deserialize<D, I, C>(
144 stream_iter: &mut I,
145 type_name: &'static str,
146 ctx: &mut C,
147) -> Result<D, String>
148where
149 C: TestDeserializeContext,
150 D: DeserializeOwned + MzReflect,
151 I: Iterator<Item = TokenTree>,
152{
153 deserialize_optional(stream_iter, type_name, ctx)?
154 .ok_or_else(|| format!("Empty spec for type {}", type_name))
155}
156
157pub fn to_json<I, C>(
185 stream_iter: &mut I,
186 type_name: &str,
187 rti: &ReflectedTypeInfo,
188 ctx: &mut C,
189) -> Result<Option<String>, String>
190where
191 C: TestDeserializeContext,
192 I: Iterator<Item = TokenTree>,
193{
194 let (type_name, option_found) = normalize_type_name(type_name);
195
196 if let Some((_, f_types)) = rti.struct_dict.get(&type_name[..]) {
198 if f_types.is_empty() {
199 return Ok(Some("null".to_string()));
200 }
201 }
202
203 if let Some(first_arg) = stream_iter.next() {
204 if option_found {
207 if let TokenTree::Ident(ident) = &first_arg {
208 if *ident == "null" {
209 return Ok(Some("null".to_string()));
210 }
211 }
212 }
213
214 if let Some(result) =
220 parse_as_enum_or_struct(first_arg.clone(), stream_iter, &type_name, rti, ctx)?
221 {
222 return Ok(Some(result));
223 }
224 if let Some(result) = ctx.override_syntax(first_arg.clone(), stream_iter, &type_name)? {
226 return Ok(Some(result));
227 }
228 match first_arg {
229 TokenTree::Group(group) => {
230 let mut inner_iter = group.stream().into_iter();
231 match group.delimiter() {
232 Delimiter::Bracket => {
233 if type_name.starts_with("Vec<") && type_name.ends_with('>') {
234 let vec = parse_as_vec(
236 &mut inner_iter,
237 &type_name[4..(type_name.len() - 1)],
238 rti,
239 ctx,
240 )?;
241 Ok(Some(format!("[{}]", separated(",", vec.iter()))))
242 } else if type_name.starts_with("[") && type_name.ends_with(']') {
243 let vec = parse_as_vec(
245 &mut inner_iter,
246 &type_name[1..(type_name.len() - 1)],
247 rti,
248 ctx,
249 )?;
250 Ok(Some(format!("[{}]", separated(",", vec.iter()))))
251 } else if type_name.starts_with('(') && type_name.ends_with(')') {
252 let vec = parse_as_tuple(
253 &mut inner_iter,
254 &type_name[1..(type_name.len() - 1)],
255 rti,
256 ctx,
257 )?;
258 Ok(Some(format!("[{}]", separated(",", vec.iter()))))
259 } else {
260 Err(format!(
261 "Object specified with brackets {:?} has unsupported type `{}`",
262 inner_iter.collect::<Vec<_>>(),
263 type_name
264 ))
265 }
266 }
267 delim => Err(format!(
268 "Object spec {:?} (type {}) has unsupported delimiter {:?}",
269 inner_iter.collect::<Vec<_>>(),
270 type_name,
271 delim
272 )),
273 }
274 }
275 TokenTree::Punct(punct) => {
276 match punct.as_char() {
277 ',' => to_json(stream_iter, &type_name, rti, ctx),
280 other => match to_json(stream_iter, &type_name, rti, ctx)? {
284 Some(result) => Ok(Some(format!("{}{}", other, result))),
285 None => Ok(Some(other.to_string())),
286 },
287 }
288 }
289 TokenTree::Ident(ident) => {
290 if type_name == "String" && ident != "null" {
294 Ok(Some(ident.to_string().quoted().to_string()))
295 } else {
296 Ok(Some(ident.to_string()))
297 }
298 }
299 TokenTree::Literal(literal) => Ok(Some(literal.to_string())),
300 }
301 } else {
302 Ok(None)
303 }
304}
305
/// Hooks that let a test customize how particular types are written in, and
/// printed to, the test syntax.
pub trait TestDeserializeContext {
    /// Gives the context a chance to convert the spec that starts with
    /// `first_arg` (followed by `rest_of_stream`) for an object of type
    /// `type_name` into a JSON string.
    ///
    /// `type_name` has already been normalized: whitespace is removed and
    /// outer `Option<>`/`Box<>` wrappers are stripped. Returning `Ok(None)`
    /// means the default syntax should be used instead.
    ///
    /// For types described in the reflected type info, this is consulted
    /// before the default enum/struct syntax. For all other types, it is
    /// consulted before the built-in list/punctuation/ident/literal handling.
    fn override_syntax<I>(
        &mut self,
        first_arg: TokenTree,
        rest_of_stream: &mut I,
        type_name: &str,
    ) -> Result<Option<String>, String>
    where
        I: Iterator<Item = TokenTree>;

    /// The reverse of `override_syntax`: gives the context a chance to print
    /// `json`, an object of type `type_name`, in the test syntax.
    ///
    /// Returning `None` means the default printing should be used.
    fn reverse_syntax_override(&mut self, json: &Value, type_name: &str) -> Option<String>;
}
339
340#[derive(Default)]
344struct GenericTestDeserializeContext;
345
346impl TestDeserializeContext for GenericTestDeserializeContext {
347 fn override_syntax<I>(
348 &mut self,
349 _first_arg: TokenTree,
350 _rest_of_stream: &mut I,
351 _type_name: &str,
352 ) -> Result<Option<String>, String>
353 where
354 I: Iterator<Item = TokenTree>,
355 {
356 Ok(None)
357 }
358
359 fn reverse_syntax_override(&mut self, _: &Value, _: &str) -> Option<String> {
360 None
361 }
362}
363
364fn parse_as_vec<I, C>(
368 stream_iter: &mut I,
369 type_name: &str,
370 rti: &ReflectedTypeInfo,
371 ctx: &mut C,
372) -> Result<Vec<String>, String>
373where
374 C: TestDeserializeContext,
375 I: Iterator<Item = TokenTree>,
376{
377 let mut result = Vec::new();
378 while let Some(element) = to_json(stream_iter, type_name, rti, ctx)? {
379 result.push(element);
380 }
381 Ok(result)
382}
383
384fn parse_as_tuple<I, C>(
388 stream_iter: &mut I,
389 type_name: &str,
390 rti: &ReflectedTypeInfo,
391 ctx: &mut C,
392) -> Result<Vec<String>, String>
393where
394 C: TestDeserializeContext,
395 I: Iterator<Item = TokenTree>,
396{
397 let mut prev_elem_end = 0;
398 let mut result = Vec::new();
399 while let Some((next_elem_begin, next_elem_end)) =
400 find_next_type_in_tuple(type_name, prev_elem_end)
401 {
402 match to_json(
403 stream_iter,
404 &type_name[next_elem_begin..next_elem_end],
405 rti,
406 ctx,
407 )? {
408 Some(elem) => result.push(elem),
409 None => break,
412 }
413 prev_elem_end = next_elem_end;
414 }
415 Ok(result)
416}
417
418fn parse_as_enum_or_struct<I, C>(
424 first_arg: TokenTree,
425 rest_of_stream: &mut I,
426 type_name: &str,
427 rti: &ReflectedTypeInfo,
428 ctx: &mut C,
429) -> Result<Option<String>, String>
430where
431 C: TestDeserializeContext,
432 I: Iterator<Item = TokenTree>,
433{
434 if rti.enum_dict.contains_key(type_name) || rti.struct_dict.contains_key(type_name) {
435 match first_arg {
440 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
441 let mut inner_iter = group.stream().into_iter();
442 match inner_iter.next() {
443 Some(first_arg) => parse_as_enum_or_struct_inner(
445 first_arg,
446 &mut inner_iter,
447 type_name,
448 rti,
449 ctx,
450 ),
451 None => Ok(None),
452 }
453 }
454 TokenTree::Punct(punct) => {
455 let mut consecutive_punct = Vec::new();
460 while let Some(token) = rest_of_stream.next() {
461 consecutive_punct.push(token);
462 match &consecutive_punct[consecutive_punct.len() - 1] {
463 TokenTree::Punct(_) => {}
464 _ => {
465 break;
466 }
467 }
468 }
469 parse_as_enum_or_struct_inner(
470 TokenTree::Punct(punct),
471 &mut consecutive_punct.into_iter(),
472 type_name,
473 rti,
474 ctx,
475 )
476 }
477 other => {
478 parse_as_enum_or_struct_inner(other, &mut std::iter::empty(), type_name, rti, ctx)
482 }
483 }
484 } else {
485 Ok(None)
486 }
487}
488
489fn parse_as_enum_or_struct_inner<I, C>(
490 first_arg: TokenTree,
491 rest_of_stream: &mut I,
492 type_name: &str,
493 rti: &ReflectedTypeInfo,
494 ctx: &mut C,
495) -> Result<Option<String>, String>
496where
497 C: TestDeserializeContext,
498 I: Iterator<Item = TokenTree>,
499{
500 if let Some(result) = ctx.override_syntax(first_arg.clone(), rest_of_stream, type_name)? {
501 Ok(Some(result))
502 } else if let Some((f_names, f_types)) = rti.struct_dict.get(type_name).map(|r| r.clone()) {
503 Ok(Some(to_json_fields(
504 type_name,
505 &mut (&mut std::iter::once(first_arg)).chain(rest_of_stream),
506 f_names,
507 f_types,
508 rti,
509 ctx,
510 )?))
511 } else if let TokenTree::Ident(ident) = first_arg {
512 Ok(Some(to_json_generic_enum(
513 ident.to_string(),
514 rest_of_stream,
515 type_name,
516 rti,
517 ctx,
518 )?))
519 } else {
520 Ok(None)
521 }
522}
523
524fn to_json_generic_enum<I, C>(
526 variant_snake_case: String,
527 rest_of_stream: &mut I,
528 type_name: &str,
529 rti: &ReflectedTypeInfo,
530 ctx: &mut C,
531) -> Result<String, String>
532where
533 C: TestDeserializeContext,
534 I: Iterator<Item = TokenTree>,
535{
536 let variant_camel_case = variant_snake_case
538 .split('_')
539 .map(|s| {
540 let mut chars = s.chars();
541 let result = chars
542 .next()
543 .map(|c| c.to_uppercase().chain(chars).collect::<String>())
544 .unwrap_or_else(String::new);
545 result
546 })
547 .collect::<Vec<_>>()
548 .concat();
549 let (f_names, f_types) = rti
550 .enum_dict
551 .get(type_name)
552 .unwrap()
553 .get(&variant_camel_case[..])
554 .map(|v| v.clone())
555 .ok_or_else(|| {
556 format!(
557 "{}::{} is not a supported enum.",
558 type_name, variant_camel_case
559 )
560 })?;
561 if f_types.is_empty() {
564 Ok(format!("\"{}\"", variant_camel_case))
566 } else {
567 let fields = to_json_fields(
568 &variant_camel_case,
569 rest_of_stream,
570 f_names,
571 f_types,
572 rti,
573 ctx,
574 )?;
575 Ok(format!("{{\"{}\":{}}}", variant_camel_case, fields))
576 }
577}
578
579fn to_json_fields<I, C>(
585 debug_name: &str,
586 stream_iter: &mut I,
587 f_names: Vec<&'static str>,
588 f_types: Vec<&'static str>,
589 rti: &ReflectedTypeInfo,
590 ctx: &mut C,
591) -> Result<String, String>
592where
593 C: TestDeserializeContext,
594 I: Iterator<Item = TokenTree>,
595{
596 let mut f_values = Vec::new();
597 for t in f_types.iter() {
598 match to_json(stream_iter, t, rti, ctx)? {
599 Some(value) => f_values.push(value),
600 None => {
601 break;
602 }
603 }
604 }
605 if !f_names.is_empty() {
606 Ok(format!(
609 "{{{}}}",
610 separated(
611 ",",
612 f_names
613 .iter()
614 .take(f_values.len())
615 .zip_eq(f_values.into_iter())
616 .map(|(n, v)| format!("\"{}\":{}", n, v))
617 )
618 ))
619 } else {
620 if f_types.len() == 1 {
624 Ok(f_values
625 .pop()
626 .ok_or_else(|| format!("Cannot use default value for {}", debug_name))?)
627 } else {
628 Ok(format!("[{}]", separated(",", f_values.into_iter())))
629 }
630 }
631}
632
633pub fn serialize_generic<M>(json: &Value, type_name: &str) -> String
637where
638 M: MzReflect,
639{
640 let mut rti = ReflectedTypeInfo::default();
641 M::add_to_reflected_type_info(&mut rti);
642 from_json(
643 json,
644 type_name,
645 &rti,
646 &mut GenericTestDeserializeContext::default(),
647 )
648}
649
650pub fn serialize<M, C>(json: &Value, type_name: &str, ctx: &mut C) -> String
651where
652 C: TestDeserializeContext,
653 M: MzReflect,
654{
655 let mut rti = ReflectedTypeInfo::default();
656 M::add_to_reflected_type_info(&mut rti);
657 from_json(json, type_name, &rti, ctx)
658}
659
660pub fn from_json<C>(json: &Value, type_name: &str, rti: &ReflectedTypeInfo, ctx: &mut C) -> String
667where
668 C: TestDeserializeContext,
669{
670 let (type_name, option_found) = normalize_type_name(type_name);
671 if option_found {
675 if let Value::Null = json {
676 return "null".to_string();
677 }
678 }
679 if let Some(result) = ctx.reverse_syntax_override(json, &type_name) {
680 return result;
681 }
682 if let Some((names, types)) = rti.struct_dict.get(&type_name[..]) {
683 if types.is_empty() {
684 "".to_string()
685 } else {
686 format!("({})", from_json_fields(json, names, types, rti, ctx))
687 }
688 } else if let Some(enum_dict) = rti.enum_dict.get(&type_name[..]) {
689 match json {
690 Value::String(s) => unquote(s),
692 Value::Object(map) => {
695 assert_eq!(
697 map.len(),
698 1,
699 "Multivariant instance {:?} found for enum {}",
700 map,
701 type_name
702 );
703 for (variant, data) in map.iter() {
704 if let Some((names, types)) = enum_dict.get(&variant[..]) {
705 return format!(
706 "({} {})",
707 variant,
708 from_json_fields(data, names, types, rti, ctx)
709 );
710 }
711 }
712 unreachable!()
713 }
714 _ => unreachable!("Invalid json {:?} for enum type {}", json, type_name),
715 }
716 } else {
717 match json {
718 Value::Array(members) => {
719 let result = if type_name.starts_with("Vec<") && type_name.ends_with('>') {
720 members
722 .iter()
723 .map(|v| from_json(v, &type_name[4..(type_name.len() - 1)], rti, ctx))
724 .collect::<Vec<_>>()
725 } else {
726 let mut result = Vec::new();
728 let type_name = &type_name[1..(type_name.len() - 1)];
729 let mut prev_elem_end = 0;
730 let mut members_iter = members.into_iter();
731 while let Some((next_elem_begin, next_elem_end)) =
732 find_next_type_in_tuple(type_name, prev_elem_end)
733 {
734 match members_iter.next() {
735 Some(elem) => result.push(from_json(
736 elem,
737 &type_name[next_elem_begin..next_elem_end],
738 rti,
739 ctx,
740 )),
741 None => break,
743 }
744 prev_elem_end = next_elem_end;
745 }
746 result
747 };
748 format!("[{}]", separated(" ", result))
750 }
751 Value::Object(map) => {
752 unreachable!("Invalid map {:?} found for type {}", map, type_name)
753 }
754 other => other.to_string(),
755 }
756 }
757}
758
759fn from_json_fields<C>(
760 v: &Value,
761 f_names: &[&'static str],
762 f_types: &[&'static str],
763 rti: &ReflectedTypeInfo,
764 ctx: &mut C,
765) -> String
766where
767 C: TestDeserializeContext,
768{
769 match v {
770 Value::Object(map) if !f_names.is_empty() => {
776 let mut fields = Vec::with_capacity(f_types.len());
777 for (name, typ) in f_names.iter().zip_eq(f_types.iter()) {
778 fields.push(from_json(&map[*name], typ, rti, ctx))
779 }
780 separated(" ", fields).to_string()
781 }
782 Value::Array(inner) if f_types.len() > 1 => {
785 let mut fields = Vec::with_capacity(f_types.len());
786 for (v, typ) in inner.iter().zip_eq(f_types.iter()) {
787 fields.push(from_json(v, typ, rti, ctx))
788 }
789 separated(" ", fields).to_string()
790 }
791 other => from_json(other, f_types.first().unwrap(), rti, ctx),
793 }
794}
795
/// Canonicalizes a type name for lookup in [`ReflectedTypeInfo`].
///
/// All whitespace is removed, then any number of outer `Option<...>` and
/// `Box<...>` wrappers are peeled off. Returns the remaining type name and
/// whether at least one `Option` wrapper was removed.
fn normalize_type_name(type_name: &str) -> (String, bool) {
    // Whitespace is never significant for lookup. Strip all of it, not just
    // spaces and `\n`, so that tabs and CRLF line endings do not break it.
    let compact: String = type_name.chars().filter(|c| !c.is_whitespace()).collect();
    let mut rest = compact.as_str();
    let mut option_found = false;
    loop {
        if let Some(inner) = rest.strip_prefix("Option<").and_then(|t| t.strip_suffix('>')) {
            option_found = true;
            rest = inner;
        } else if let Some(inner) = rest.strip_prefix("Box<").and_then(|t| t.strip_suffix('>')) {
            rest = inner;
        } else {
            break;
        }
    }
    (rest.to_string(), option_found)
}
822
/// Finds the next element type in `type_name`, a comma-separated list of
/// tuple element types without the surrounding parentheses.
///
/// `prev_elem_end` is the end offset returned by the previous call, or `0`
/// for the first call. Returns the byte range `(begin, end)` of the next
/// element, or `None` when there are no more. Commas nested inside `(...)`
/// or `<...>` do not split elements.
fn find_next_type_in_tuple(type_name: &str, prev_elem_end: usize) -> Option<(usize, usize)> {
    // After the first element, skip the comma that ended the previous one.
    let current_elem_begin = if prev_elem_end > 0 {
        prev_elem_end + 1
    } else {
        prev_elem_end
    };
    if current_elem_begin >= type_name.len() {
        return None;
    }
    // All delimiters are ASCII, so scanning bytes keeps every offset a valid
    // `str` slice boundary even for non-ASCII names.
    let mut paren_level = 0i32;
    let mut bracket_level = 0i32;
    let mut end = type_name.len();
    for (offset, &byte) in type_name.as_bytes()[current_elem_begin..].iter().enumerate() {
        match byte {
            b',' if paren_level == 0 && bracket_level == 0 => {
                end = current_elem_begin + offset;
                break;
            }
            b'(' => paren_level += 1,
            b')' => paren_level -= 1,
            b'<' => bracket_level += 1,
            b'>' => bracket_level -= 1,
            _ => {}
        }
    }
    Some((current_elem_begin, end))
}
859
860