1use std::{borrow::Cow, collections::HashMap, convert::TryInto, fmt, str::FromStr};
2
3use prost::bytes::Bytes;
4use serde::de::{DeserializeSeed, Deserializer, Error, IgnoredAny, MapAccess, SeqAccess, Visitor};
5
6use crate::{
7 dynamic::{serde::DeserializeOptions, DynamicMessage, MapKey, Value},
8 EnumDescriptor, Kind, MessageDescriptor, ReflectMessage,
9};
10
11use super::{
12 deserialize_enum, deserialize_message, FieldDescriptorSeed, OptionalFieldDescriptorSeed,
13};
14
/// A [`DeserializeSeed`] that deserializes a single value of the protobuf field kind in `.0`,
/// using the deserialization options in `.1`.
///
/// Produces `Option<Value>`: scalar and message kinds always yield `Some`, while enum kinds
/// delegate to `deserialize_enum`, which may yield `None`.
pub struct KindSeed<'a>(pub &'a Kind, pub &'a DeserializeOptions);
16
17impl<'de> DeserializeSeed<'de> for KindSeed<'_> {
18 type Value = Option<Value>;
19
20 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
21 where
22 D: Deserializer<'de>,
23 {
24 match self.0 {
25 Kind::Double => Ok(Some(Value::F64(
26 deserializer.deserialize_any(DoubleVisitor)?,
27 ))),
28 Kind::Float => Ok(Some(Value::F32(
29 deserializer.deserialize_any(FloatVisitor)?,
30 ))),
31 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Ok(Some(Value::I32(
32 deserializer.deserialize_any(Int32Visitor)?,
33 ))),
34 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Ok(Some(Value::I64(
35 deserializer.deserialize_any(Int64Visitor)?,
36 ))),
37 Kind::Uint32 | Kind::Fixed32 => Ok(Some(Value::U32(
38 deserializer.deserialize_any(Uint32Visitor)?,
39 ))),
40 Kind::Uint64 | Kind::Fixed64 => Ok(Some(Value::U64(
41 deserializer.deserialize_any(Uint64Visitor)?,
42 ))),
43 Kind::Bool => Ok(Some(Value::Bool(
44 deserializer.deserialize_any(BoolVisitor)?,
45 ))),
46 Kind::String => Ok(Some(Value::String(
47 deserializer.deserialize_string(StringVisitor)?,
48 ))),
49 Kind::Bytes => Ok(Some(Value::Bytes(
50 deserializer.deserialize_str(BytesVisitor)?,
51 ))),
52 Kind::Message(desc) => Ok(Some(Value::Message(deserialize_message(
53 desc,
54 deserializer,
55 self.1,
56 )?))),
57 Kind::Enum(desc) => {
58 Ok(deserialize_enum(desc, deserializer, self.1)?.map(Value::EnumNumber))
59 }
60 }
61 }
62}
63
/// Visitor for a repeated field: each sequence element is deserialized with [`KindSeed`]
/// using the element kind in `.0` and the options in `.1`.
pub struct ListVisitor<'a>(pub &'a Kind, pub &'a DeserializeOptions);
/// Visitor for a map field. `.0` is expected to be a map-entry message kind; its key and
/// value fields determine how each entry's key string and value are deserialized.
pub struct MapVisitor<'a>(pub &'a Kind, pub &'a DeserializeOptions);
/// Visitor producing an `f64` from a float, an integer, or a numeric string
/// (including "Infinity", "-Infinity" and "NaN").
pub struct DoubleVisitor;
/// Visitor producing an `f32` from a float (range-checked), an integer, or a numeric string.
pub struct FloatVisitor;
/// Visitor producing an `i32` from an integer, an integral float, or a decimal string.
pub struct Int32Visitor;
/// Visitor producing a `u32` from an integer, an integral float, or a decimal string.
pub struct Uint32Visitor;
/// Visitor producing an `i64` from an integer, an integral float, or a decimal string.
pub struct Int64Visitor;
/// Visitor producing a `u64` from an integer, an integral float, or a decimal string.
pub struct Uint64Visitor;
/// Visitor producing an owned `String`.
pub struct StringVisitor;
/// Visitor accepting only boolean input.
pub struct BoolVisitor;
/// Visitor decoding a base64 string (standard or URL-safe alphabet, padding optional) into `Bytes`.
pub struct BytesVisitor;
/// Visitor that builds a new [`DynamicMessage`] for descriptor `.0` from a map.
pub struct MessageVisitor<'a>(pub &'a MessageDescriptor, pub &'a DeserializeOptions);
/// Visitor that merges the entries of a map into the existing message `.0`.
pub struct MessageVisitorInner<'a>(pub &'a mut DynamicMessage, pub &'a DeserializeOptions);
/// Visitor resolving an enum value from its name or number. It yields `None` for an
/// unrecognized name when `deny_unknown_fields` is not set.
pub struct EnumVisitor<'a>(pub &'a EnumDescriptor, pub &'a DeserializeOptions);
78
79impl<'de> Visitor<'de> for ListVisitor<'_> {
80 type Value = Vec<Value>;
81
82 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 write!(f, "a list")
84 }
85
86 #[inline]
87 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
88 where
89 A: SeqAccess<'de>,
90 {
91 let mut result = Vec::with_capacity(seq.size_hint().unwrap_or(0));
92
93 while let Some(value) = seq.next_element_seed(KindSeed(self.0, self.1))? {
94 if let Some(value) = value {
95 result.push(value)
96 }
97 }
98
99 Ok(result)
100 }
101}
102
103impl<'de> Visitor<'de> for MapVisitor<'_> {
104 type Value = HashMap<MapKey, Value>;
105
106 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
107 write!(f, "a map")
108 }
109
110 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
111 where
112 A: MapAccess<'de>,
113 {
114 let mut result = HashMap::with_capacity(map.size_hint().unwrap_or(0));
115
116 let map_entry_message = self.0.as_message().unwrap();
117 let key_kind = map_entry_message.map_entry_key_field().kind();
118 let value_desc = map_entry_message.map_entry_value_field();
119
120 while let Some(key_str) = map.next_key::<Cow<str>>()? {
121 let key = match key_kind {
122 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
123 MapKey::I32(i32::from_str(key_str.as_ref()).map_err(Error::custom)?)
124 }
125 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
126 MapKey::I64(i64::from_str(key_str.as_ref()).map_err(Error::custom)?)
127 }
128 Kind::Uint32 | Kind::Fixed32 => {
129 MapKey::U32(u32::from_str(key_str.as_ref()).map_err(Error::custom)?)
130 }
131 Kind::Uint64 | Kind::Fixed64 => {
132 MapKey::U64(u64::from_str(key_str.as_ref()).map_err(Error::custom)?)
133 }
134 Kind::Bool => {
135 MapKey::Bool(bool::from_str(key_str.as_ref()).map_err(Error::custom)?)
136 }
137 Kind::String => MapKey::String(key_str.into_owned()),
138 _ => unreachable!("invalid type for map key"),
139 };
140
141 let value = map.next_value_seed(FieldDescriptorSeed(&value_desc, self.1))?;
142 if let Some(value) = value {
143 result.insert(key, value);
144 }
145 }
146
147 Ok(result)
148 }
149}
150
151impl Visitor<'_> for DoubleVisitor {
152 type Value = f64;
153
154 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
155 write!(f, "a 64-bit floating point value")
156 }
157
158 #[inline]
159 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
160 where
161 E: Error,
162 {
163 Ok(v)
164 }
165
166 #[inline]
167 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
168 where
169 E: Error,
170 {
171 Ok(v as Self::Value)
172 }
173
174 #[inline]
175 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
176 where
177 E: Error,
178 {
179 Ok(v as Self::Value)
180 }
181
182 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
183 where
184 E: Error,
185 {
186 match f64::from_str(v) {
187 Ok(value) => Ok(value),
188 Err(_) if v == "Infinity" => Ok(f64::INFINITY),
189 Err(_) if v == "-Infinity" => Ok(f64::NEG_INFINITY),
190 Err(_) if v == "NaN" => Ok(f64::NAN),
191 Err(err) => Err(Error::custom(err)),
192 }
193 }
194}
195
196impl Visitor<'_> for FloatVisitor {
197 type Value = f32;
198
199 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
200 write!(f, "a 32-bit floating point value")
201 }
202
203 #[inline]
204 fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
205 where
206 E: Error,
207 {
208 Ok(v)
209 }
210
211 #[inline]
212 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
213 where
214 E: Error,
215 {
216 if v < (f32::MIN as f64) || v > (f32::MAX as f64) {
217 Err(Error::custom("float value out of range"))
218 } else {
219 Ok(v as f32)
220 }
221 }
222
223 #[inline]
224 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
225 where
226 E: Error,
227 {
228 Ok(v as Self::Value)
229 }
230
231 #[inline]
232 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
233 where
234 E: Error,
235 {
236 Ok(v as Self::Value)
237 }
238
239 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
240 where
241 E: Error,
242 {
243 match f32::from_str(v) {
244 Ok(value) => Ok(value),
245 Err(_) if v == "Infinity" => Ok(f32::INFINITY),
246 Err(_) if v == "-Infinity" => Ok(f32::NEG_INFINITY),
247 Err(_) if v == "NaN" => Ok(f32::NAN),
248 Err(err) => Err(Error::custom(err)),
249 }
250 }
251}
252
253impl Visitor<'_> for Int32Visitor {
254 type Value = i32;
255
256 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
257 write!(f, "a 32-bit signed integer")
258 }
259
260 #[inline]
261 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
262 where
263 E: Error,
264 {
265 v.parse().map_err(Error::custom)
266 }
267
268 #[inline]
269 fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
270 where
271 E: Error,
272 {
273 Ok(v)
274 }
275
276 #[inline]
277 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
278 where
279 E: Error,
280 {
281 v.try_into().map_err(Error::custom)
282 }
283
284 #[inline]
285 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
286 where
287 E: Error,
288 {
289 v.try_into().map_err(Error::custom)
290 }
291
292 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
293 where
294 E: Error,
295 {
296 if v.fract() != 0.0 {
297 return Err(Error::custom("expected integer value"));
298 }
299
300 if v < (i32::MIN as f64) || v > (i32::MAX as f64) {
301 return Err(Error::custom("float value out of range"));
302 }
303
304 Ok(v as i32)
305 }
306}
307
308impl Visitor<'_> for Uint32Visitor {
309 type Value = u32;
310
311 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
312 write!(f, "a 32-bit unsigned integer or decimal string")
313 }
314
315 #[inline]
316 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
317 where
318 E: Error,
319 {
320 v.parse().map_err(Error::custom)
321 }
322
323 #[inline]
324 fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
325 where
326 E: Error,
327 {
328 Ok(v)
329 }
330
331 #[inline]
332 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
333 where
334 E: Error,
335 {
336 v.try_into().map_err(Error::custom)
337 }
338
339 #[inline]
340 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
341 where
342 E: Error,
343 {
344 v.try_into().map_err(Error::custom)
345 }
346
347 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
348 where
349 E: Error,
350 {
351 if v.fract() != 0.0 {
352 return Err(Error::custom("expected integer value"));
353 }
354
355 if v < (u32::MIN as f64) || v > (u32::MAX as f64) {
356 return Err(Error::custom("float value out of range"));
357 }
358
359 Ok(v as u32)
360 }
361}
362
363impl Visitor<'_> for Int64Visitor {
364 type Value = i64;
365
366 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
367 write!(f, "a 64-bit signed integer or decimal string")
368 }
369
370 #[inline]
371 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
372 where
373 E: Error,
374 {
375 v.parse().map_err(Error::custom)
376 }
377
378 #[inline]
379 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
380 where
381 E: Error,
382 {
383 Ok(v)
384 }
385
386 #[inline]
387 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
388 where
389 E: Error,
390 {
391 v.try_into().map_err(Error::custom)
392 }
393
394 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
395 where
396 E: Error,
397 {
398 if v.fract() != 0.0 {
399 return Err(Error::custom("expected integer value"));
400 }
401
402 if v < (i64::MIN as f64) || v > (i64::MAX as f64) {
403 return Err(Error::custom("float value out of range"));
404 }
405
406 Ok(v as i64)
407 }
408}
409
410impl Visitor<'_> for Uint64Visitor {
411 type Value = u64;
412
413 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
414 write!(f, "a 64-bit unsigned integer or decimal string")
415 }
416
417 #[inline]
418 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
419 where
420 E: Error,
421 {
422 v.parse().map_err(Error::custom)
423 }
424
425 #[inline]
426 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
427 where
428 E: Error,
429 {
430 Ok(v)
431 }
432
433 #[inline]
434 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
435 where
436 E: Error,
437 {
438 v.try_into().map_err(Error::custom)
439 }
440
441 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
442 where
443 E: Error,
444 {
445 if v.fract() != 0.0 {
446 return Err(Error::custom("expected integer value"));
447 }
448
449 if v < (u64::MIN as f64) || v > (u64::MAX as f64) {
450 return Err(Error::custom("float value out of range"));
451 }
452
453 Ok(v as u64)
454 }
455}
456
457impl Visitor<'_> for StringVisitor {
458 type Value = String;
459
460 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
461 write!(f, "a string")
462 }
463
464 #[inline]
465 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
466 where
467 E: Error,
468 {
469 Ok(v.to_owned())
470 }
471
472 #[inline]
473 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
474 where
475 E: Error,
476 {
477 Ok(v)
478 }
479}
480
481impl Visitor<'_> for BoolVisitor {
482 type Value = bool;
483
484 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
485 write!(f, "a boolean")
486 }
487
488 #[inline]
489 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
490 where
491 E: Error,
492 {
493 Ok(v)
494 }
495}
496
497impl Visitor<'_> for BytesVisitor {
498 type Value = Bytes;
499
500 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
501 write!(f, "a base64-encoded string")
502 }
503
504 #[inline]
505 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
506 where
507 E: Error,
508 {
509 use base64::{
510 alphabet,
511 engine::DecodePaddingMode,
512 engine::{GeneralPurpose, GeneralPurposeConfig},
513 DecodeError, Engine,
514 };
515
516 const CONFIG: GeneralPurposeConfig = GeneralPurposeConfig::new()
517 .with_decode_allow_trailing_bits(true)
518 .with_decode_padding_mode(DecodePaddingMode::Indifferent);
519 const STANDARD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, CONFIG);
520 const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, CONFIG);
521
522 let mut buf = Vec::new();
523 match STANDARD.decode_vec(v, &mut buf) {
524 Ok(()) => Ok(buf.into()),
525 Err(DecodeError::InvalidByte(_, b'-')) | Err(DecodeError::InvalidByte(_, b'_')) => {
526 buf.clear();
527 match URL_SAFE.decode_vec(v, &mut buf) {
528 Ok(()) => Ok(buf.into()),
529 Err(err) => Err(Error::custom(format!("invalid base64: {}", err))),
530 }
531 }
532 Err(err) => Err(Error::custom(format!("invalid base64: {}", err))),
533 }
534 }
535}
536
537impl<'de> Visitor<'de> for MessageVisitor<'_> {
538 type Value = DynamicMessage;
539
540 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
541 write!(f, "a map")
542 }
543
544 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
545 where
546 A: MapAccess<'de>,
547 {
548 let mut message = DynamicMessage::new(self.0.clone());
549
550 MessageVisitorInner(&mut message, self.1).visit_map(map)?;
551
552 Ok(message)
553 }
554}
555
556impl<'de> Visitor<'de> for MessageVisitorInner<'_> {
557 type Value = ();
558
559 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
560 write!(f, "a map")
561 }
562
563 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
564 where
565 A: MapAccess<'de>,
566 {
567 let desc = self.0.descriptor();
568 while let Some(key) = map.next_key::<Cow<str>>()? {
569 if let Some(field) = desc
570 .get_field_by_json_name(key.as_ref())
571 .or_else(|| desc.get_field_by_name(key.as_ref()))
572 {
573 if let Some(value) =
574 map.next_value_seed(OptionalFieldDescriptorSeed(&field, self.1))?
575 {
576 if let Some(oneof_desc) = field.containing_oneof() {
577 for oneof_field in oneof_desc.fields() {
578 if self.0.has_field(&oneof_field) {
579 return Err(Error::custom(format!(
580 "multiple fields provided for oneof '{}'",
581 oneof_desc.name()
582 )));
583 }
584 }
585 }
586
587 self.0.set_field(&field, value);
588 }
589 } else if let Some(extension_desc) = desc.get_extension_by_json_name(key.as_ref()) {
590 if let Some(value) =
591 map.next_value_seed(OptionalFieldDescriptorSeed(&extension_desc, self.1))?
592 {
593 self.0.set_extension(&extension_desc, value);
594 }
595 } else if self.1.deny_unknown_fields {
596 return Err(Error::custom(format!("unrecognized field name '{}'", key)));
597 } else {
598 let _ = map.next_value::<IgnoredAny>()?;
599 }
600 }
601
602 Ok(())
603 }
604}
605
606impl Visitor<'_> for EnumVisitor<'_> {
607 type Value = Option<i32>;
608
609 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
610 write!(f, "a string or integer")
611 }
612
613 #[inline]
614 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
615 where
616 E: Error,
617 {
618 match self.0.get_value_by_name(v) {
619 Some(e) => Ok(Some(e.number())),
620 None => {
621 if self.1.deny_unknown_fields {
622 Err(Error::custom(format!("unrecognized enum value '{}'", v)))
623 } else {
624 Ok(None)
625 }
626 }
627 }
628 }
629
630 #[inline]
631 fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
632 where
633 E: Error,
634 {
635 Ok(Some(v))
636 }
637
638 #[inline]
639 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
640 where
641 E: Error,
642 {
643 self.visit_i32(v.try_into().map_err(Error::custom)?)
644 }
645
646 #[inline]
647 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
648 where
649 E: Error,
650 {
651 self.visit_i32(v.try_into().map_err(Error::custom)?)
652 }
653}