prost_reflect/dynamic/serde/de/
kind.rs

1use std::{borrow::Cow, collections::HashMap, convert::TryInto, fmt, str::FromStr};
2
3use prost::bytes::Bytes;
4use serde::de::{DeserializeSeed, Deserializer, Error, IgnoredAny, MapAccess, SeqAccess, Visitor};
5
6use crate::{
7    dynamic::{serde::DeserializeOptions, DynamicMessage, MapKey, Value},
8    EnumDescriptor, Kind, MessageDescriptor, ReflectMessage,
9};
10
11use super::{
12    deserialize_enum, deserialize_message, FieldDescriptorSeed, OptionalFieldDescriptorSeed,
13};
14
15pub struct KindSeed<'a>(pub &'a Kind, pub &'a DeserializeOptions);
16
17impl<'de> DeserializeSeed<'de> for KindSeed<'_> {
18    type Value = Option<Value>;
19
20    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
21    where
22        D: Deserializer<'de>,
23    {
24        match self.0 {
25            Kind::Double => Ok(Some(Value::F64(
26                deserializer.deserialize_any(DoubleVisitor)?,
27            ))),
28            Kind::Float => Ok(Some(Value::F32(
29                deserializer.deserialize_any(FloatVisitor)?,
30            ))),
31            Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Ok(Some(Value::I32(
32                deserializer.deserialize_any(Int32Visitor)?,
33            ))),
34            Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Ok(Some(Value::I64(
35                deserializer.deserialize_any(Int64Visitor)?,
36            ))),
37            Kind::Uint32 | Kind::Fixed32 => Ok(Some(Value::U32(
38                deserializer.deserialize_any(Uint32Visitor)?,
39            ))),
40            Kind::Uint64 | Kind::Fixed64 => Ok(Some(Value::U64(
41                deserializer.deserialize_any(Uint64Visitor)?,
42            ))),
43            Kind::Bool => Ok(Some(Value::Bool(
44                deserializer.deserialize_any(BoolVisitor)?,
45            ))),
46            Kind::String => Ok(Some(Value::String(
47                deserializer.deserialize_string(StringVisitor)?,
48            ))),
49            Kind::Bytes => Ok(Some(Value::Bytes(
50                deserializer.deserialize_str(BytesVisitor)?,
51            ))),
52            Kind::Message(desc) => Ok(Some(Value::Message(deserialize_message(
53                desc,
54                deserializer,
55                self.1,
56            )?))),
57            Kind::Enum(desc) => {
58                Ok(deserialize_enum(desc, deserializer, self.1)?.map(Value::EnumNumber))
59            }
60        }
61    }
62}
63
64pub struct ListVisitor<'a>(pub &'a Kind, pub &'a DeserializeOptions);
65pub struct MapVisitor<'a>(pub &'a Kind, pub &'a DeserializeOptions);
66pub struct DoubleVisitor;
67pub struct FloatVisitor;
68pub struct Int32Visitor;
69pub struct Uint32Visitor;
70pub struct Int64Visitor;
71pub struct Uint64Visitor;
72pub struct StringVisitor;
73pub struct BoolVisitor;
74pub struct BytesVisitor;
75pub struct MessageVisitor<'a>(pub &'a MessageDescriptor, pub &'a DeserializeOptions);
76pub struct MessageVisitorInner<'a>(pub &'a mut DynamicMessage, pub &'a DeserializeOptions);
77pub struct EnumVisitor<'a>(pub &'a EnumDescriptor, pub &'a DeserializeOptions);
78
79impl<'de> Visitor<'de> for ListVisitor<'_> {
80    type Value = Vec<Value>;
81
82    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
83        write!(f, "a list")
84    }
85
86    #[inline]
87    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
88    where
89        A: SeqAccess<'de>,
90    {
91        let mut result = Vec::with_capacity(seq.size_hint().unwrap_or(0));
92
93        while let Some(value) = seq.next_element_seed(KindSeed(self.0, self.1))? {
94            if let Some(value) = value {
95                result.push(value)
96            }
97        }
98
99        Ok(result)
100    }
101}
102
103impl<'de> Visitor<'de> for MapVisitor<'_> {
104    type Value = HashMap<MapKey, Value>;
105
106    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
107        write!(f, "a map")
108    }
109
110    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
111    where
112        A: MapAccess<'de>,
113    {
114        let mut result = HashMap::with_capacity(map.size_hint().unwrap_or(0));
115
116        let map_entry_message = self.0.as_message().unwrap();
117        let key_kind = map_entry_message.map_entry_key_field().kind();
118        let value_desc = map_entry_message.map_entry_value_field();
119
120        while let Some(key_str) = map.next_key::<Cow<str>>()? {
121            let key = match key_kind {
122                Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
123                    MapKey::I32(i32::from_str(key_str.as_ref()).map_err(Error::custom)?)
124                }
125                Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
126                    MapKey::I64(i64::from_str(key_str.as_ref()).map_err(Error::custom)?)
127                }
128                Kind::Uint32 | Kind::Fixed32 => {
129                    MapKey::U32(u32::from_str(key_str.as_ref()).map_err(Error::custom)?)
130                }
131                Kind::Uint64 | Kind::Fixed64 => {
132                    MapKey::U64(u64::from_str(key_str.as_ref()).map_err(Error::custom)?)
133                }
134                Kind::Bool => {
135                    MapKey::Bool(bool::from_str(key_str.as_ref()).map_err(Error::custom)?)
136                }
137                Kind::String => MapKey::String(key_str.into_owned()),
138                _ => unreachable!("invalid type for map key"),
139            };
140
141            let value = map.next_value_seed(FieldDescriptorSeed(&value_desc, self.1))?;
142            if let Some(value) = value {
143                result.insert(key, value);
144            }
145        }
146
147        Ok(result)
148    }
149}
150
151impl Visitor<'_> for DoubleVisitor {
152    type Value = f64;
153
154    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
155        write!(f, "a 64-bit floating point value")
156    }
157
158    #[inline]
159    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
160    where
161        E: Error,
162    {
163        Ok(v)
164    }
165
166    #[inline]
167    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
168    where
169        E: Error,
170    {
171        Ok(v as Self::Value)
172    }
173
174    #[inline]
175    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
176    where
177        E: Error,
178    {
179        Ok(v as Self::Value)
180    }
181
182    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
183    where
184        E: Error,
185    {
186        match f64::from_str(v) {
187            Ok(value) => Ok(value),
188            Err(_) if v == "Infinity" => Ok(f64::INFINITY),
189            Err(_) if v == "-Infinity" => Ok(f64::NEG_INFINITY),
190            Err(_) if v == "NaN" => Ok(f64::NAN),
191            Err(err) => Err(Error::custom(err)),
192        }
193    }
194}
195
196impl Visitor<'_> for FloatVisitor {
197    type Value = f32;
198
199    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
200        write!(f, "a 32-bit floating point value")
201    }
202
203    #[inline]
204    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
205    where
206        E: Error,
207    {
208        Ok(v)
209    }
210
211    #[inline]
212    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
213    where
214        E: Error,
215    {
216        if v < (f32::MIN as f64) || v > (f32::MAX as f64) {
217            Err(Error::custom("float value out of range"))
218        } else {
219            Ok(v as f32)
220        }
221    }
222
223    #[inline]
224    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
225    where
226        E: Error,
227    {
228        Ok(v as Self::Value)
229    }
230
231    #[inline]
232    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
233    where
234        E: Error,
235    {
236        Ok(v as Self::Value)
237    }
238
239    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
240    where
241        E: Error,
242    {
243        match f32::from_str(v) {
244            Ok(value) => Ok(value),
245            Err(_) if v == "Infinity" => Ok(f32::INFINITY),
246            Err(_) if v == "-Infinity" => Ok(f32::NEG_INFINITY),
247            Err(_) if v == "NaN" => Ok(f32::NAN),
248            Err(err) => Err(Error::custom(err)),
249        }
250    }
251}
252
253impl Visitor<'_> for Int32Visitor {
254    type Value = i32;
255
256    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
257        write!(f, "a 32-bit signed integer")
258    }
259
260    #[inline]
261    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
262    where
263        E: Error,
264    {
265        v.parse().map_err(Error::custom)
266    }
267
268    #[inline]
269    fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
270    where
271        E: Error,
272    {
273        Ok(v)
274    }
275
276    #[inline]
277    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
278    where
279        E: Error,
280    {
281        v.try_into().map_err(Error::custom)
282    }
283
284    #[inline]
285    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
286    where
287        E: Error,
288    {
289        v.try_into().map_err(Error::custom)
290    }
291
292    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
293    where
294        E: Error,
295    {
296        if v.fract() != 0.0 {
297            return Err(Error::custom("expected integer value"));
298        }
299
300        if v < (i32::MIN as f64) || v > (i32::MAX as f64) {
301            return Err(Error::custom("float value out of range"));
302        }
303
304        Ok(v as i32)
305    }
306}
307
308impl Visitor<'_> for Uint32Visitor {
309    type Value = u32;
310
311    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
312        write!(f, "a 32-bit unsigned integer or decimal string")
313    }
314
315    #[inline]
316    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
317    where
318        E: Error,
319    {
320        v.parse().map_err(Error::custom)
321    }
322
323    #[inline]
324    fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
325    where
326        E: Error,
327    {
328        Ok(v)
329    }
330
331    #[inline]
332    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
333    where
334        E: Error,
335    {
336        v.try_into().map_err(Error::custom)
337    }
338
339    #[inline]
340    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
341    where
342        E: Error,
343    {
344        v.try_into().map_err(Error::custom)
345    }
346
347    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
348    where
349        E: Error,
350    {
351        if v.fract() != 0.0 {
352            return Err(Error::custom("expected integer value"));
353        }
354
355        if v < (u32::MIN as f64) || v > (u32::MAX as f64) {
356            return Err(Error::custom("float value out of range"));
357        }
358
359        Ok(v as u32)
360    }
361}
362
363impl Visitor<'_> for Int64Visitor {
364    type Value = i64;
365
366    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
367        write!(f, "a 64-bit signed integer or decimal string")
368    }
369
370    #[inline]
371    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
372    where
373        E: Error,
374    {
375        v.parse().map_err(Error::custom)
376    }
377
378    #[inline]
379    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
380    where
381        E: Error,
382    {
383        Ok(v)
384    }
385
386    #[inline]
387    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
388    where
389        E: Error,
390    {
391        v.try_into().map_err(Error::custom)
392    }
393
394    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
395    where
396        E: Error,
397    {
398        if v.fract() != 0.0 {
399            return Err(Error::custom("expected integer value"));
400        }
401
402        if v < (i64::MIN as f64) || v > (i64::MAX as f64) {
403            return Err(Error::custom("float value out of range"));
404        }
405
406        Ok(v as i64)
407    }
408}
409
410impl Visitor<'_> for Uint64Visitor {
411    type Value = u64;
412
413    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
414        write!(f, "a 64-bit unsigned integer or decimal string")
415    }
416
417    #[inline]
418    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
419    where
420        E: Error,
421    {
422        v.parse().map_err(Error::custom)
423    }
424
425    #[inline]
426    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
427    where
428        E: Error,
429    {
430        Ok(v)
431    }
432
433    #[inline]
434    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
435    where
436        E: Error,
437    {
438        v.try_into().map_err(Error::custom)
439    }
440
441    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
442    where
443        E: Error,
444    {
445        if v.fract() != 0.0 {
446            return Err(Error::custom("expected integer value"));
447        }
448
449        if v < (u64::MIN as f64) || v > (u64::MAX as f64) {
450            return Err(Error::custom("float value out of range"));
451        }
452
453        Ok(v as u64)
454    }
455}
456
457impl Visitor<'_> for StringVisitor {
458    type Value = String;
459
460    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
461        write!(f, "a string")
462    }
463
464    #[inline]
465    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
466    where
467        E: Error,
468    {
469        Ok(v.to_owned())
470    }
471
472    #[inline]
473    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
474    where
475        E: Error,
476    {
477        Ok(v)
478    }
479}
480
481impl Visitor<'_> for BoolVisitor {
482    type Value = bool;
483
484    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
485        write!(f, "a boolean")
486    }
487
488    #[inline]
489    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
490    where
491        E: Error,
492    {
493        Ok(v)
494    }
495}
496
497impl Visitor<'_> for BytesVisitor {
498    type Value = Bytes;
499
500    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
501        write!(f, "a base64-encoded string")
502    }
503
504    #[inline]
505    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
506    where
507        E: Error,
508    {
509        use base64::{
510            alphabet,
511            engine::DecodePaddingMode,
512            engine::{GeneralPurpose, GeneralPurposeConfig},
513            DecodeError, Engine,
514        };
515
516        const CONFIG: GeneralPurposeConfig = GeneralPurposeConfig::new()
517            .with_decode_allow_trailing_bits(true)
518            .with_decode_padding_mode(DecodePaddingMode::Indifferent);
519        const STANDARD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, CONFIG);
520        const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, CONFIG);
521
522        let mut buf = Vec::new();
523        match STANDARD.decode_vec(v, &mut buf) {
524            Ok(()) => Ok(buf.into()),
525            Err(DecodeError::InvalidByte(_, b'-')) | Err(DecodeError::InvalidByte(_, b'_')) => {
526                buf.clear();
527                match URL_SAFE.decode_vec(v, &mut buf) {
528                    Ok(()) => Ok(buf.into()),
529                    Err(err) => Err(Error::custom(format!("invalid base64: {}", err))),
530                }
531            }
532            Err(err) => Err(Error::custom(format!("invalid base64: {}", err))),
533        }
534    }
535}
536
537impl<'de> Visitor<'de> for MessageVisitor<'_> {
538    type Value = DynamicMessage;
539
540    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
541        write!(f, "a map")
542    }
543
544    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
545    where
546        A: MapAccess<'de>,
547    {
548        let mut message = DynamicMessage::new(self.0.clone());
549
550        MessageVisitorInner(&mut message, self.1).visit_map(map)?;
551
552        Ok(message)
553    }
554}
555
556impl<'de> Visitor<'de> for MessageVisitorInner<'_> {
557    type Value = ();
558
559    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
560        write!(f, "a map")
561    }
562
563    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
564    where
565        A: MapAccess<'de>,
566    {
567        let desc = self.0.descriptor();
568        while let Some(key) = map.next_key::<Cow<str>>()? {
569            if let Some(field) = desc
570                .get_field_by_json_name(key.as_ref())
571                .or_else(|| desc.get_field_by_name(key.as_ref()))
572            {
573                if let Some(value) =
574                    map.next_value_seed(OptionalFieldDescriptorSeed(&field, self.1))?
575                {
576                    if let Some(oneof_desc) = field.containing_oneof() {
577                        for oneof_field in oneof_desc.fields() {
578                            if self.0.has_field(&oneof_field) {
579                                return Err(Error::custom(format!(
580                                    "multiple fields provided for oneof '{}'",
581                                    oneof_desc.name()
582                                )));
583                            }
584                        }
585                    }
586
587                    self.0.set_field(&field, value);
588                }
589            } else if let Some(extension_desc) = desc.get_extension_by_json_name(key.as_ref()) {
590                if let Some(value) =
591                    map.next_value_seed(OptionalFieldDescriptorSeed(&extension_desc, self.1))?
592                {
593                    self.0.set_extension(&extension_desc, value);
594                }
595            } else if self.1.deny_unknown_fields {
596                return Err(Error::custom(format!("unrecognized field name '{}'", key)));
597            } else {
598                let _ = map.next_value::<IgnoredAny>()?;
599            }
600        }
601
602        Ok(())
603    }
604}
605
606impl Visitor<'_> for EnumVisitor<'_> {
607    type Value = Option<i32>;
608
609    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
610        write!(f, "a string or integer")
611    }
612
613    #[inline]
614    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
615    where
616        E: Error,
617    {
618        match self.0.get_value_by_name(v) {
619            Some(e) => Ok(Some(e.number())),
620            None => {
621                if self.1.deny_unknown_fields {
622                    Err(Error::custom(format!("unrecognized enum value '{}'", v)))
623                } else {
624                    Ok(None)
625                }
626            }
627        }
628    }
629
630    #[inline]
631    fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
632    where
633        E: Error,
634    {
635        Ok(Some(v))
636    }
637
638    #[inline]
639    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
640    where
641        E: Error,
642    {
643        self.visit_i32(v.try_into().map_err(Error::custom)?)
644    }
645
646    #[inline]
647    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
648    where
649        E: Error,
650    {
651        self.visit_i32(v.try_into().map_err(Error::custom)?)
652    }
653}