prost_reflect/dynamic/serde/de/wkt.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, HashMap},
4    fmt,
5    marker::PhantomData,
6};
7
8use prost::Message;
9use serde::de::{
10    DeserializeSeed, Deserializer, Error, IgnoredAny, IntoDeserializer, MapAccess, SeqAccess,
11    Visitor,
12};
13
14use crate::{
15    dynamic::{
16        get_type_url_message_name,
17        serde::{
18            case::camel_case_to_snake_case, check_duration, check_timestamp, is_well_known_type,
19            DeserializeOptions,
20        },
21        DynamicMessage,
22    },
23    DescriptorPool,
24};
25
26use super::{deserialize_message, kind::MessageVisitorInner, MessageSeed};
27
28pub struct GoogleProtobufAnyVisitor<'a>(pub &'a DescriptorPool, pub &'a DeserializeOptions);
29pub struct GoogleProtobufNullVisitor<'a>(pub &'a DeserializeOptions);
30pub struct GoogleProtobufTimestampVisitor;
31pub struct GoogleProtobufDurationVisitor;
32pub struct GoogleProtobufFieldMaskVisitor;
33pub struct GoogleProtobufListVisitor;
34pub struct GoogleProtobufStructVisitor;
35pub struct GoogleProtobufValueVisitor;
36pub struct GoogleProtobufEmptyVisitor;
37
38impl<'de> Visitor<'de> for GoogleProtobufAnyVisitor<'_> {
39    type Value = prost_types::Any;
40
41    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        write!(f, "a map")
43    }
44
45    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
46    where
47        A: MapAccess<'de>,
48    {
49        let mut buffered_entries = HashMap::new();
50
51        let type_url = find_field(
52            &mut map,
53            &mut buffered_entries,
54            "@type",
55            PhantomData::<String>,
56        )?;
57
58        let message_name = get_type_url_message_name(&type_url).map_err(Error::custom)?;
59        let message_desc = self
60            .0
61            .get_message_by_name(message_name)
62            .ok_or_else(|| Error::custom(format!("message '{}' not found", message_name)))?;
63
64        let payload_message = if is_well_known_type(message_name) {
65            let payload_message = match buffered_entries.remove("value") {
66                Some(value) => {
67                    deserialize_message(&message_desc, value, self.1).map_err(Error::custom)?
68                }
69                None => find_field(
70                    &mut map,
71                    &mut buffered_entries,
72                    "value",
73                    MessageSeed(&message_desc, self.1),
74                )?,
75            };
76
77            if self.1.deny_unknown_fields {
78                if let Some(key) = buffered_entries.keys().next() {
79                    return Err(Error::custom(format!("unrecognized field name '{}'", key)));
80                }
81                if let Some(key) = map.next_key::<Cow<str>>()? {
82                    return Err(Error::custom(format!("unrecognized field name '{}'", key)));
83                }
84            } else {
85                drop(buffered_entries);
86                while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
87            }
88
89            payload_message
90        } else {
91            let mut payload_message = DynamicMessage::new(message_desc);
92
93            buffered_entries
94                .into_deserializer()
95                .deserialize_map(MessageVisitorInner(&mut payload_message, self.1))
96                .map_err(Error::custom)?;
97
98            MessageVisitorInner(&mut payload_message, self.1).visit_map(map)?;
99
100            payload_message
101        };
102
103        let value = payload_message.encode_to_vec();
104        Ok(prost_types::Any { type_url, value })
105    }
106}
107
108impl Visitor<'_> for GoogleProtobufNullVisitor<'_> {
109    type Value = Option<i32>;
110
111    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        write!(f, "null")
113    }
114
115    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
116    where
117        E: Error,
118    {
119        if v == "NULL_VALUE" {
120            Ok(Some(0))
121        } else if self.0.deny_unknown_fields {
122            Err(Error::custom("expected null"))
123        } else {
124            Ok(None)
125        }
126    }
127
128    #[inline]
129    fn visit_unit<E>(self) -> Result<Self::Value, E>
130    where
131        E: Error,
132    {
133        Ok(Some(0))
134    }
135}
136
137impl Visitor<'_> for GoogleProtobufTimestampVisitor {
138    type Value = prost_types::Timestamp;
139
140    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
141        write!(f, "a rfc3339 timestamp string")
142    }
143
144    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
145    where
146        E: Error,
147    {
148        validate_strict_rfc3339(v).map_err(Error::custom)?;
149
150        let timestamp: prost_types::Timestamp = v.parse().map_err(Error::custom)?;
151
152        check_timestamp(&timestamp).map_err(Error::custom)?;
153
154        Ok(timestamp)
155    }
156}
157
158impl Visitor<'_> for GoogleProtobufDurationVisitor {
159    type Value = prost_types::Duration;
160
161    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
162        write!(f, "a duration string")
163    }
164
165    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
166    where
167        E: Error,
168    {
169        let duration: prost_types::Duration = v.parse().map_err(Error::custom)?;
170
171        check_duration(&duration).map_err(Error::custom)?;
172
173        Ok(duration)
174    }
175}
176
177impl Visitor<'_> for GoogleProtobufFieldMaskVisitor {
178    type Value = prost_types::FieldMask;
179
180    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
181    where
182        E: Error,
183    {
184        let paths = v
185            .split(',')
186            .filter(|path| !path.is_empty())
187            .map(|path| {
188                let mut result = String::new();
189                let mut parts = path.split('.');
190
191                if let Some(part) = parts.next() {
192                    camel_case_to_snake_case(&mut result, part)?;
193                }
194                for part in parts {
195                    result.push('.');
196                    camel_case_to_snake_case(&mut result, part)?;
197                }
198
199                Ok(result)
200            })
201            .collect::<Result<_, ()>>()
202            .map_err(|()| Error::custom("invalid field mask"))?;
203
204        Ok(prost_types::FieldMask { paths })
205    }
206
207    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
208        write!(f, "a field mask string")
209    }
210}
211
212impl<'de> DeserializeSeed<'de> for GoogleProtobufValueVisitor {
213    type Value = prost_types::Value;
214
215    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
216    where
217        D: Deserializer<'de>,
218    {
219        deserializer.deserialize_any(self)
220    }
221}
222
223impl<'de> Visitor<'de> for GoogleProtobufListVisitor {
224    type Value = prost_types::ListValue;
225
226    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
227    where
228        A: SeqAccess<'de>,
229    {
230        let mut values = Vec::with_capacity(seq.size_hint().unwrap_or(0));
231        while let Some(value) = seq.next_element_seed(GoogleProtobufValueVisitor)? {
232            values.push(value);
233        }
234        Ok(prost_types::ListValue { values })
235    }
236
237    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
238        write!(f, "a list")
239    }
240}
241
242impl<'de> Visitor<'de> for GoogleProtobufStructVisitor {
243    type Value = prost_types::Struct;
244
245    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
246    where
247        A: MapAccess<'de>,
248    {
249        let mut fields = BTreeMap::new();
250        while let Some(key) = map.next_key::<String>()? {
251            let value = map.next_value_seed(GoogleProtobufValueVisitor)?;
252            fields.insert(key, value);
253        }
254        Ok(prost_types::Struct { fields })
255    }
256
257    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
258        write!(f, "a map")
259    }
260}
261
262impl<'de> Visitor<'de> for GoogleProtobufValueVisitor {
263    type Value = prost_types::Value;
264
265    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
266        write!(f, "a value")
267    }
268
269    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
270    where
271        E: Error,
272    {
273        Ok(prost_types::Value {
274            kind: Some(prost_types::value::Kind::BoolValue(v)),
275        })
276    }
277
278    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
279    where
280        E: Error,
281    {
282        self.visit_f64(v as f64)
283    }
284
285    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
286    where
287        E: Error,
288    {
289        self.visit_f64(v as f64)
290    }
291
292    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
293    where
294        E: Error,
295    {
296        Ok(prost_types::Value {
297            kind: Some(prost_types::value::Kind::NumberValue(v)),
298        })
299    }
300
301    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
302    where
303        E: Error,
304    {
305        self.visit_string(v.to_owned())
306    }
307
308    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
309    where
310        E: Error,
311    {
312        Ok(prost_types::Value {
313            kind: Some(prost_types::value::Kind::StringValue(v)),
314        })
315    }
316
317    #[inline]
318    fn visit_unit<E>(self) -> Result<Self::Value, E>
319    where
320        E: Error,
321    {
322        Ok(prost_types::Value {
323            kind: Some(prost_types::value::Kind::NullValue(0)),
324        })
325    }
326
327    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
328    where
329        A: SeqAccess<'de>,
330    {
331        GoogleProtobufListVisitor
332            .visit_seq(seq)
333            .map(|l| prost_types::Value {
334                kind: Some(prost_types::value::Kind::ListValue(l)),
335            })
336    }
337
338    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
339    where
340        A: MapAccess<'de>,
341    {
342        GoogleProtobufStructVisitor
343            .visit_map(map)
344            .map(|s| prost_types::Value {
345                kind: Some(prost_types::value::Kind::StructValue(s)),
346            })
347    }
348}
349
350impl<'de> Visitor<'de> for GoogleProtobufEmptyVisitor {
351    type Value = ();
352
353    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
354    where
355        A: MapAccess<'de>,
356    {
357        if map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {
358            return Err(Error::custom("unexpected value in map"));
359        }
360
361        Ok(())
362    }
363
364    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
365        write!(f, "an empty map")
366    }
367}
368
369fn find_field<'de, A, D>(
370    map: &mut A,
371    buffered_entries: &mut HashMap<Cow<str>, serde_value::Value>,
372    expected: &str,
373    value_seed: D,
374) -> Result<D::Value, A::Error>
375where
376    A: MapAccess<'de>,
377    D: DeserializeSeed<'de>,
378{
379    loop {
380        match map.next_key::<Cow<str>>()? {
381            Some(key) if key == expected => return map.next_value_seed(value_seed),
382            Some(key) => {
383                buffered_entries.insert(key, map.next_value()?);
384            }
385            None => return Err(Error::custom(format!("expected '{expected}' field"))),
386        }
387    }
388}
389
/// Checks that `v` has the shape of a strict RFC 3339 date-time:
/// `YYYY-MM-DDTHH:MM:SS[.fraction](Z|+HH:MM|-HH:MM)`. The `T` and `Z` separators must be
/// upper-case, as the conformance tests recommend.
///
/// Only the syntax is checked here. Field values such as the month or hour are not
/// range-checked. On failure, the returned message starts with `"invalid rfc3339 timestamp: "`
/// and describes the first problem found.
fn validate_strict_rfc3339(v: &str) -> Result<(), String> {
    /// A forward-only cursor over the bytes of the input.
    struct Cursor<'s> {
        bytes: &'s [u8],
        pos: usize,
    }

    impl Cursor<'_> {
        /// The byte at the cursor, if any, without consuming it.
        fn peek(&self) -> Option<u8> {
            self.bytes.get(self.pos).copied()
        }

        /// Consumes the next byte when `accept` returns true for it.
        fn eat(&mut self, accept: impl Fn(u8) -> bool) -> bool {
            match self.peek() {
                Some(b) if accept(b) => {
                    self.pos += 1;
                    true
                }
                _ => false,
            }
        }

        /// Consumes the next byte if it equals `expected`.
        fn eat_byte(&mut self, expected: u8) -> bool {
            self.eat(|b| b == expected)
        }

        /// Consumes exactly `count` ASCII digits, stopping at the first non-digit.
        fn eat_digits(&mut self, count: usize) -> bool {
            (0..count).all(|_| self.eat(|b| b.is_ascii_digit()))
        }

        /// Describes the next byte for error messages (escaped and quoted), or end of input.
        fn describe_next(&self) -> String {
            self.peek().map_or_else(
                || "end of string".to_owned(),
                |b| format!("'{}'", std::ascii::escape_default(b)),
            )
        }
    }

    let mut cur = Cursor {
        bytes: v.as_bytes(),
        pos: 0,
    };

    let date_ok = cur.eat_digits(4)
        && cur.eat_byte(b'-')
        && cur.eat_digits(2)
        && cur.eat_byte(b'-')
        && cur.eat_digits(2);
    if !date_ok {
        return Err("invalid rfc3339 timestamp: invalid date".to_owned());
    }

    if !cur.eat_byte(b'T') {
        return Err(format!(
            "invalid rfc3339 timestamp: expected 'T' but found {}",
            cur.describe_next()
        ));
    }

    let time_ok = cur.eat_digits(2)
        && cur.eat_byte(b':')
        && cur.eat_digits(2)
        && cur.eat_byte(b':')
        && cur.eat_digits(2);
    if !time_ok {
        return Err("invalid rfc3339 timestamp: invalid time".to_owned());
    }

    // Optional fractional seconds: a '.' must be followed by at least one digit.
    if cur.eat_byte(b'.') {
        if !cur.eat_digits(1) {
            return Err("invalid rfc3339 timestamp: empty fractional seconds".to_owned());
        }
        while cur.eat(|b| b.is_ascii_digit()) {}
    }

    // Zone designator: a numeric "+HH:MM" / "-HH:MM" offset, or a literal upper-case 'Z'.
    if cur.eat(|b| b == b'+' || b == b'-') {
        let offset_ok = cur.eat_digits(2) && cur.eat_byte(b':') && cur.eat_digits(2);
        if !offset_ok {
            return Err("invalid rfc3339 timestamp: invalid offset".to_owned());
        }
    } else if !cur.eat_byte(b'Z') {
        return Err(format!(
            "invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found {}",
            cur.describe_next()
        ));
    }

    match cur.peek() {
        None => Ok(()),
        Some(_) => Err(format!(
            "invalid rfc3339 timestamp: expected end of string but found {}",
            cur.describe_next()
        )),
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_strict_rfc3339() {
        macro_rules! case {
            ($s:expr => Ok) => {
                assert_eq!(validate_strict_rfc3339($s), Ok(()))
            };
            ($s:expr => Err($e:expr)) => {
                assert_eq!(validate_strict_rfc3339($s).unwrap_err().to_string(), $e)
            };
        }

        case!("1972-06-30T23:59:60Z" => Ok);
        case!("2019-03-26T14:00:00.9Z" => Ok);
        case!("2019-03-26T14:00:00.4999Z" => Ok);
        case!("2019-03-26T14:00:00.4999+10:00" => Ok);
        case!("2019-03-26t14:00Z" => Err("invalid rfc3339 timestamp: expected 'T' but found 't'"));
        case!("2019-03-26T14:00z" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26T14:00:00,999Z" => Err("invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found ','"));
        case!("2019-03-26T10:00-04" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26T14:00.9Z" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("20190326T1400Z" => Err("invalid rfc3339 timestamp: invalid date"));
        case!("2019-02-30" => Err("invalid rfc3339 timestamp: expected 'T' but found end of string"));
        case!("2019-03-25T24:01Z" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26T14:00+24:00" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26Z" => Err("invalid rfc3339 timestamp: expected 'T' but found 'Z'"));
        case!("2019-03-26+01:00" => Err("invalid rfc3339 timestamp: expected 'T' but found '+'"));
        case!("2019-03-26-04:00" => Err("invalid rfc3339 timestamp: expected 'T' but found '-'"));
        case!("2019-03-26T10:00-0400" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("+0002019-03-26T14:00Z" => Err("invalid rfc3339 timestamp: invalid date"));
        case!("+2019-03-26T14:00Z" => Err("invalid rfc3339 timestamp: invalid date"));
        case!("002019-03-26T14:00Z" => Err("invalid rfc3339 timestamp: invalid date"));
        case!("019-03-26T14:00Z" => Err("invalid rfc3339 timestamp: invalid date"));
        case!("2019-03-26T10:00Q" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26T10:00T" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26Q" => Err("invalid rfc3339 timestamp: expected 'T' but found 'Q'"));
        case!("2019-03-26T" => Err("invalid rfc3339 timestamp: invalid time"));
        case!("2019-03-26 14:00Z" => Err("invalid rfc3339 timestamp: expected 'T' but found ' '"));
        case!("2019-03-26T14:00:00." => Err("invalid rfc3339 timestamp: empty fractional seconds"));

        // Numeric offsets: the negative form and malformed offsets, which previously had no
        // coverage of the "invalid offset" branch.
        case!("2019-03-26T14:00:00-04:00" => Ok);
        case!("2019-03-26T14:00:00.123456789Z" => Ok);
        case!("2019-03-26T14:00:00+1000" => Err("invalid rfc3339 timestamp: invalid offset"));
        case!("2019-03-26T14:00:00+05:3" => Err("invalid rfc3339 timestamp: invalid offset"));
        case!("2019-03-26T14:00:00+" => Err("invalid rfc3339 timestamp: invalid offset"));

        // Missing or lower-case zone designator after a complete time.
        case!("2019-03-26T14:00:00z" => Err("invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found 'z'"));
        case!("2019-03-26T14:00:00.5" => Err("invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found end of string"));

        // Trailing input after a valid timestamp.
        case!("2019-03-26T14:00:00Zjunk" => Err("invalid rfc3339 timestamp: expected end of string but found 'j'"));
        case!("2019-03-26T14:00:00+01:00Z" => Err("invalid rfc3339 timestamp: expected end of string but found 'Z'"));

        // Empty input, and a non-ASCII byte (reported escaped).
        case!("" => Err("invalid rfc3339 timestamp: invalid date"));
        case!("2019-03-26T14:00:00\u{e9}" => Err("invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found '\\xc3'"));
    }
}