1use std::{
2 borrow::Cow,
3 collections::{BTreeMap, HashMap},
4 fmt,
5 marker::PhantomData,
6};
7
8use prost::Message;
9use serde::de::{
10 DeserializeSeed, Deserializer, Error, IgnoredAny, IntoDeserializer, MapAccess, SeqAccess,
11 Visitor,
12};
13
14use crate::{
15 dynamic::{
16 get_type_url_message_name,
17 serde::{
18 case::camel_case_to_snake_case, check_duration, check_timestamp, is_well_known_type,
19 DeserializeOptions,
20 },
21 DynamicMessage,
22 },
23 DescriptorPool,
24};
25
26use super::{deserialize_message, kind::MessageVisitorInner, MessageSeed};
27
/// Visitor for the JSON map form of `google.protobuf.Any`.
///
/// The `@type` URL is resolved against the wrapped [`DescriptorPool`], and the
/// wrapped [`DeserializeOptions`] are forwarded to the payload deserializer.
pub struct GoogleProtobufAnyVisitor<'a>(pub &'a DescriptorPool, pub &'a DeserializeOptions);
/// Visitor for `google.protobuf.NullValue`; accepts a unit/null or the string
/// `"NULL_VALUE"`. The options decide whether any other string is an error.
pub struct GoogleProtobufNullVisitor<'a>(pub &'a DeserializeOptions);
/// Visitor producing a `prost_types::Timestamp` from an RFC 3339 string.
pub struct GoogleProtobufTimestampVisitor;
/// Visitor producing a `prost_types::Duration` from its string form.
pub struct GoogleProtobufDurationVisitor;
/// Visitor producing a `prost_types::FieldMask` from a comma-separated string.
pub struct GoogleProtobufFieldMaskVisitor;
/// Visitor producing a `prost_types::ListValue` from a sequence.
pub struct GoogleProtobufListVisitor;
/// Visitor producing a `prost_types::Struct` from a map with string keys.
pub struct GoogleProtobufStructVisitor;
/// Visitor (and `DeserializeSeed`) producing a `prost_types::Value` from any
/// self-describing input.
pub struct GoogleProtobufValueVisitor;
/// Visitor for `google.protobuf.Empty`; accepts only a map with no entries.
pub struct GoogleProtobufEmptyVisitor;
37
38impl<'de> Visitor<'de> for GoogleProtobufAnyVisitor<'_> {
39 type Value = prost_types::Any;
40
41 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 write!(f, "a map")
43 }
44
45 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
46 where
47 A: MapAccess<'de>,
48 {
49 let mut buffered_entries = HashMap::new();
50
51 let type_url = find_field(
52 &mut map,
53 &mut buffered_entries,
54 "@type",
55 PhantomData::<String>,
56 )?;
57
58 let message_name = get_type_url_message_name(&type_url).map_err(Error::custom)?;
59 let message_desc = self
60 .0
61 .get_message_by_name(message_name)
62 .ok_or_else(|| Error::custom(format!("message '{}' not found", message_name)))?;
63
64 let payload_message = if is_well_known_type(message_name) {
65 let payload_message = match buffered_entries.remove("value") {
66 Some(value) => {
67 deserialize_message(&message_desc, value, self.1).map_err(Error::custom)?
68 }
69 None => find_field(
70 &mut map,
71 &mut buffered_entries,
72 "value",
73 MessageSeed(&message_desc, self.1),
74 )?,
75 };
76
77 if self.1.deny_unknown_fields {
78 if let Some(key) = buffered_entries.keys().next() {
79 return Err(Error::custom(format!("unrecognized field name '{}'", key)));
80 }
81 if let Some(key) = map.next_key::<Cow<str>>()? {
82 return Err(Error::custom(format!("unrecognized field name '{}'", key)));
83 }
84 } else {
85 drop(buffered_entries);
86 while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
87 }
88
89 payload_message
90 } else {
91 let mut payload_message = DynamicMessage::new(message_desc);
92
93 buffered_entries
94 .into_deserializer()
95 .deserialize_map(MessageVisitorInner(&mut payload_message, self.1))
96 .map_err(Error::custom)?;
97
98 MessageVisitorInner(&mut payload_message, self.1).visit_map(map)?;
99
100 payload_message
101 };
102
103 let value = payload_message.encode_to_vec();
104 Ok(prost_types::Any { type_url, value })
105 }
106}
107
108impl Visitor<'_> for GoogleProtobufNullVisitor<'_> {
109 type Value = Option<i32>;
110
111 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 write!(f, "null")
113 }
114
115 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
116 where
117 E: Error,
118 {
119 if v == "NULL_VALUE" {
120 Ok(Some(0))
121 } else if self.0.deny_unknown_fields {
122 Err(Error::custom("expected null"))
123 } else {
124 Ok(None)
125 }
126 }
127
128 #[inline]
129 fn visit_unit<E>(self) -> Result<Self::Value, E>
130 where
131 E: Error,
132 {
133 Ok(Some(0))
134 }
135}
136
137impl Visitor<'_> for GoogleProtobufTimestampVisitor {
138 type Value = prost_types::Timestamp;
139
140 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
141 write!(f, "a rfc3339 timestamp string")
142 }
143
144 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
145 where
146 E: Error,
147 {
148 validate_strict_rfc3339(v).map_err(Error::custom)?;
149
150 let timestamp: prost_types::Timestamp = v.parse().map_err(Error::custom)?;
151
152 check_timestamp(×tamp).map_err(Error::custom)?;
153
154 Ok(timestamp)
155 }
156}
157
158impl Visitor<'_> for GoogleProtobufDurationVisitor {
159 type Value = prost_types::Duration;
160
161 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
162 write!(f, "a duration string")
163 }
164
165 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
166 where
167 E: Error,
168 {
169 let duration: prost_types::Duration = v.parse().map_err(Error::custom)?;
170
171 check_duration(&duration).map_err(Error::custom)?;
172
173 Ok(duration)
174 }
175}
176
177impl Visitor<'_> for GoogleProtobufFieldMaskVisitor {
178 type Value = prost_types::FieldMask;
179
180 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
181 where
182 E: Error,
183 {
184 let paths = v
185 .split(',')
186 .filter(|path| !path.is_empty())
187 .map(|path| {
188 let mut result = String::new();
189 let mut parts = path.split('.');
190
191 if let Some(part) = parts.next() {
192 camel_case_to_snake_case(&mut result, part)?;
193 }
194 for part in parts {
195 result.push('.');
196 camel_case_to_snake_case(&mut result, part)?;
197 }
198
199 Ok(result)
200 })
201 .collect::<Result<_, ()>>()
202 .map_err(|()| Error::custom("invalid field mask"))?;
203
204 Ok(prost_types::FieldMask { paths })
205 }
206
207 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
208 write!(f, "a field mask string")
209 }
210}
211
212impl<'de> DeserializeSeed<'de> for GoogleProtobufValueVisitor {
213 type Value = prost_types::Value;
214
215 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
216 where
217 D: Deserializer<'de>,
218 {
219 deserializer.deserialize_any(self)
220 }
221}
222
223impl<'de> Visitor<'de> for GoogleProtobufListVisitor {
224 type Value = prost_types::ListValue;
225
226 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
227 where
228 A: SeqAccess<'de>,
229 {
230 let mut values = Vec::with_capacity(seq.size_hint().unwrap_or(0));
231 while let Some(value) = seq.next_element_seed(GoogleProtobufValueVisitor)? {
232 values.push(value);
233 }
234 Ok(prost_types::ListValue { values })
235 }
236
237 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
238 write!(f, "a list")
239 }
240}
241
242impl<'de> Visitor<'de> for GoogleProtobufStructVisitor {
243 type Value = prost_types::Struct;
244
245 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
246 where
247 A: MapAccess<'de>,
248 {
249 let mut fields = BTreeMap::new();
250 while let Some(key) = map.next_key::<String>()? {
251 let value = map.next_value_seed(GoogleProtobufValueVisitor)?;
252 fields.insert(key, value);
253 }
254 Ok(prost_types::Struct { fields })
255 }
256
257 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
258 write!(f, "a map")
259 }
260}
261
262impl<'de> Visitor<'de> for GoogleProtobufValueVisitor {
263 type Value = prost_types::Value;
264
265 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
266 write!(f, "a value")
267 }
268
269 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
270 where
271 E: Error,
272 {
273 Ok(prost_types::Value {
274 kind: Some(prost_types::value::Kind::BoolValue(v)),
275 })
276 }
277
278 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
279 where
280 E: Error,
281 {
282 self.visit_f64(v as f64)
283 }
284
285 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
286 where
287 E: Error,
288 {
289 self.visit_f64(v as f64)
290 }
291
292 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
293 where
294 E: Error,
295 {
296 Ok(prost_types::Value {
297 kind: Some(prost_types::value::Kind::NumberValue(v)),
298 })
299 }
300
301 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
302 where
303 E: Error,
304 {
305 self.visit_string(v.to_owned())
306 }
307
308 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
309 where
310 E: Error,
311 {
312 Ok(prost_types::Value {
313 kind: Some(prost_types::value::Kind::StringValue(v)),
314 })
315 }
316
317 #[inline]
318 fn visit_unit<E>(self) -> Result<Self::Value, E>
319 where
320 E: Error,
321 {
322 Ok(prost_types::Value {
323 kind: Some(prost_types::value::Kind::NullValue(0)),
324 })
325 }
326
327 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
328 where
329 A: SeqAccess<'de>,
330 {
331 GoogleProtobufListVisitor
332 .visit_seq(seq)
333 .map(|l| prost_types::Value {
334 kind: Some(prost_types::value::Kind::ListValue(l)),
335 })
336 }
337
338 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
339 where
340 A: MapAccess<'de>,
341 {
342 GoogleProtobufStructVisitor
343 .visit_map(map)
344 .map(|s| prost_types::Value {
345 kind: Some(prost_types::value::Kind::StructValue(s)),
346 })
347 }
348}
349
350impl<'de> Visitor<'de> for GoogleProtobufEmptyVisitor {
351 type Value = ();
352
353 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
354 where
355 A: MapAccess<'de>,
356 {
357 if map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {
358 return Err(Error::custom("unexpected value in map"));
359 }
360
361 Ok(())
362 }
363
364 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
365 write!(f, "an empty map")
366 }
367}
368
369fn find_field<'de, A, D>(
370 map: &mut A,
371 buffered_entries: &mut HashMap<Cow<str>, serde_value::Value>,
372 expected: &str,
373 value_seed: D,
374) -> Result<D::Value, A::Error>
375where
376 A: MapAccess<'de>,
377 D: DeserializeSeed<'de>,
378{
379 loop {
380 match map.next_key::<Cow<str>>()? {
381 Some(key) if key == expected => return map.next_value_seed(value_seed),
382 Some(key) => {
383 buffered_entries.insert(key, map.next_value()?);
384 }
385 None => return Err(Error::custom(format!("expected '{expected}' field"))),
386 }
387 }
388}
389
/// Checks that `v` has the strict RFC 3339 shape required here for
/// `google.protobuf.Timestamp` strings:
///
/// - `YYYY-MM-DDTHH:MM:SS`;
/// - optional fractional seconds, with at least one digit after the `.`;
/// - then either `Z` or a `+HH:MM` / `-HH:MM` offset, and nothing after that.
///
/// Only the shape is checked (digits in the right places, uppercase `T`/`Z`).
/// Field ranges such as month <= 12 are not validated.
///
/// # Errors
///
/// Returns a message beginning with `"invalid rfc3339 timestamp: "` that
/// describes the first problem found.
fn validate_strict_rfc3339(v: &str) -> Result<(), String> {
    use std::ascii;

    /// Forward-only cursor over the input bytes.
    struct Cursor<'a> {
        bytes: &'a [u8],
        pos: usize,
    }

    impl<'a> Cursor<'a> {
        fn peek(&self) -> Option<u8> {
            self.bytes.get(self.pos).copied()
        }

        /// Consumes the next byte if it satisfies `pred`.
        fn eat_if(&mut self, pred: impl FnOnce(u8) -> bool) -> bool {
            match self.peek() {
                Some(b) if pred(b) => {
                    self.pos += 1;
                    true
                }
                _ => false,
            }
        }

        fn eat(&mut self, expected: u8) -> bool {
            self.eat_if(|b| b == expected)
        }

        fn eat_digit(&mut self) -> bool {
            self.eat_if(|b| b.is_ascii_digit())
        }

        /// Consumes exactly `count` digits, stopping at the first non-digit.
        fn eat_digits(&mut self, count: usize) -> bool {
            (0..count).all(|_| self.eat_digit())
        }

        /// Describes the unconsumed byte for error messages.
        fn describe_next(&self) -> String {
            self.peek().map_or_else(
                || "end of string".to_owned(),
                |b| format!("'{}'", ascii::escape_default(b)),
            )
        }
    }

    let mut cur = Cursor {
        bytes: v.as_bytes(),
        pos: 0,
    };

    let date_ok = cur.eat_digits(4)
        && cur.eat(b'-')
        && cur.eat_digits(2)
        && cur.eat(b'-')
        && cur.eat_digits(2);
    if !date_ok {
        return Err("invalid rfc3339 timestamp: invalid date".to_owned());
    }

    if !cur.eat(b'T') {
        return Err(format!(
            "invalid rfc3339 timestamp: expected 'T' but found {}",
            cur.describe_next()
        ));
    }

    let time_ok = cur.eat_digits(2)
        && cur.eat(b':')
        && cur.eat_digits(2)
        && cur.eat(b':')
        && cur.eat_digits(2);
    if !time_ok {
        return Err("invalid rfc3339 timestamp: invalid time".to_owned());
    }

    if cur.eat(b'.') {
        if !cur.eat_digit() {
            return Err("invalid rfc3339 timestamp: empty fractional seconds".to_owned());
        }
        while cur.eat_digit() {}
    }

    if cur.eat_if(|b| matches!(b, b'+' | b'-')) {
        let offset_ok = cur.eat_digits(2) && cur.eat(b':') && cur.eat_digits(2);
        if !offset_ok {
            return Err("invalid rfc3339 timestamp: invalid offset".to_owned());
        }
    } else if !cur.eat(b'Z') {
        return Err(format!(
            "invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found {}",
            cur.describe_next()
        ));
    }

    match cur.peek() {
        None => Ok(()),
        Some(_) => Err(format!(
            "invalid rfc3339 timestamp: expected end of string but found {}",
            cur.describe_next()
        )),
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `validate_strict_rfc3339`: each entry pairs an
    /// input with `None` (must be accepted) or `Some(message)` (must be
    /// rejected with exactly that message).
    #[test]
    fn test_validate_strict_rfc3339() {
        let cases: &[(&str, Option<&str>)] = &[
            ("1972-06-30T23:59:60Z", None),
            ("2019-03-26T14:00:00.9Z", None),
            ("2019-03-26T14:00:00.4999Z", None),
            ("2019-03-26T14:00:00.4999+10:00", None),
            ("2019-03-26t14:00Z", Some("invalid rfc3339 timestamp: expected 'T' but found 't'")),
            ("2019-03-26T14:00z", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26T14:00:00,999Z", Some("invalid rfc3339 timestamp: expected 'Z', '+' or '-' but found ','")),
            ("2019-03-26T10:00-04", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26T14:00.9Z", Some("invalid rfc3339 timestamp: invalid time")),
            ("20190326T1400Z", Some("invalid rfc3339 timestamp: invalid date")),
            ("2019-02-30", Some("invalid rfc3339 timestamp: expected 'T' but found end of string")),
            ("2019-03-25T24:01Z", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26T14:00+24:00", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26Z", Some("invalid rfc3339 timestamp: expected 'T' but found 'Z'")),
            ("2019-03-26+01:00", Some("invalid rfc3339 timestamp: expected 'T' but found '+'")),
            ("2019-03-26-04:00", Some("invalid rfc3339 timestamp: expected 'T' but found '-'")),
            ("2019-03-26T10:00-0400", Some("invalid rfc3339 timestamp: invalid time")),
            ("+0002019-03-26T14:00Z", Some("invalid rfc3339 timestamp: invalid date")),
            ("+2019-03-26T14:00Z", Some("invalid rfc3339 timestamp: invalid date")),
            ("002019-03-26T14:00Z", Some("invalid rfc3339 timestamp: invalid date")),
            ("019-03-26T14:00Z", Some("invalid rfc3339 timestamp: invalid date")),
            ("2019-03-26T10:00Q", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26T10:00T", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26Q", Some("invalid rfc3339 timestamp: expected 'T' but found 'Q'")),
            ("2019-03-26T", Some("invalid rfc3339 timestamp: invalid time")),
            ("2019-03-26 14:00Z", Some("invalid rfc3339 timestamp: expected 'T' but found ' '")),
            ("2019-03-26T14:00:00.", Some("invalid rfc3339 timestamp: empty fractional seconds")),
        ];

        for &(input, expected_err) in cases {
            let result = validate_strict_rfc3339(input);
            match expected_err {
                None => assert_eq!(result, Ok(()), "input {:?}", input),
                Some(message) => assert_eq!(result.unwrap_err(), message, "input {:?}", input),
            }
        }
    }
}