1use base64::{display::Base64Display, prelude::BASE64_STANDARD};
2use prost::{DecodeError, Message};
3use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
4
5use crate::{
6 dynamic::{
7 get_type_url_message_name,
8 serde::{
9 case::snake_case_to_camel_case, check_duration, check_timestamp, is_well_known_type,
10 SerializeOptions,
11 },
12 DynamicMessage,
13 },
14 ReflectMessage,
15};
16
17use super::{serialize_dynamic_message_fields, SerializeWrapper};
18
19#[allow(type_alias_bounds)]
20type WellKnownTypeSerializer<S: Serializer> =
21 fn(&DynamicMessage, S, &SerializeOptions) -> Result<S::Ok, S::Error>;
22
23pub fn get_well_known_type_serializer<S>(full_name: &str) -> Option<WellKnownTypeSerializer<S>>
24where
25 S: Serializer,
26{
27 match full_name {
28 "google.protobuf.Any" => Some(serialize_any),
29 "google.protobuf.Timestamp" => Some(serialize_timestamp),
30 "google.protobuf.Duration" => Some(serialize_duration),
31 "google.protobuf.Struct" => Some(serialize_struct),
32 "google.protobuf.FloatValue" => Some(serialize_float),
33 "google.protobuf.DoubleValue" => Some(serialize_double),
34 "google.protobuf.Int32Value" => Some(serialize_int32),
35 "google.protobuf.Int64Value" => Some(serialize_int64),
36 "google.protobuf.UInt32Value" => Some(serialize_uint32),
37 "google.protobuf.UInt64Value" => Some(serialize_uint64),
38 "google.protobuf.BoolValue" => Some(serialize_bool),
39 "google.protobuf.StringValue" => Some(serialize_string),
40 "google.protobuf.BytesValue" => Some(serialize_bytes),
41 "google.protobuf.FieldMask" => Some(serialize_field_mask),
42 "google.protobuf.ListValue" => Some(serialize_list),
43 "google.protobuf.Value" => Some(serialize_value),
44 "google.protobuf.Empty" => Some(serialize_empty),
45 _ => {
46 debug_assert!(!is_well_known_type(full_name));
47 None
48 }
49 }
50}
51
52fn serialize_any<S>(
53 msg: &DynamicMessage,
54 serializer: S,
55 options: &SerializeOptions,
56) -> Result<S::Ok, S::Error>
57where
58 S: Serializer,
59{
60 let raw: prost_types::Any = msg.transcode_to().map_err(decode_to_ser_err)?;
61
62 let message_name = get_type_url_message_name(&raw.type_url).map_err(Error::custom)?;
63
64 let message_desc = msg
65 .descriptor()
66 .parent_pool()
67 .get_message_by_name(message_name)
68 .ok_or_else(|| Error::custom(format!("message '{}' not found", message_name)))?;
69
70 let mut payload_message = DynamicMessage::new(message_desc);
71 payload_message
72 .merge(raw.value.as_ref())
73 .map_err(decode_to_ser_err)?;
74
75 if is_well_known_type(message_name) {
76 let mut map = serializer.serialize_map(Some(2))?;
77 map.serialize_entry("@type", &raw.type_url)?;
78 map.serialize_entry(
79 "value",
80 &SerializeWrapper {
81 value: &payload_message,
82 options,
83 },
84 )?;
85 map.end()
86 } else {
87 let mut map = serializer.serialize_map(None)?;
88 map.serialize_entry("@type", &raw.type_url)?;
89 serialize_dynamic_message_fields(&mut map, &payload_message, options)?;
90 map.end()
91 }
92}
93
94fn serialize_timestamp<S>(
95 msg: &DynamicMessage,
96 serializer: S,
97 _options: &SerializeOptions,
98) -> Result<S::Ok, S::Error>
99where
100 S: Serializer,
101{
102 let timestamp: prost_types::Timestamp = msg.transcode_to().map_err(decode_to_ser_err)?;
103
104 check_timestamp(×tamp).map_err(Error::custom)?;
105
106 serializer.collect_str(×tamp)
107}
108
109fn serialize_duration<S>(
110 msg: &DynamicMessage,
111 serializer: S,
112 _options: &SerializeOptions,
113) -> Result<S::Ok, S::Error>
114where
115 S: Serializer,
116{
117 let duration: prost_types::Duration = msg.transcode_to().map_err(decode_to_ser_err)?;
118
119 check_duration(&duration).map_err(Error::custom)?;
120
121 serializer.collect_str(&duration)
122}
123
124fn serialize_float<S>(
125 msg: &DynamicMessage,
126 serializer: S,
127 _options: &SerializeOptions,
128) -> Result<S::Ok, S::Error>
129where
130 S: Serializer,
131{
132 let raw: f32 = msg.transcode_to().map_err(decode_to_ser_err)?;
133
134 serializer.serialize_f32(raw)
135}
136
137fn serialize_double<S>(
138 msg: &DynamicMessage,
139 serializer: S,
140 _options: &SerializeOptions,
141) -> Result<S::Ok, S::Error>
142where
143 S: Serializer,
144{
145 let raw: f64 = msg.transcode_to().map_err(decode_to_ser_err)?;
146
147 serializer.serialize_f64(raw)
148}
149
150fn serialize_int32<S>(
151 msg: &DynamicMessage,
152 serializer: S,
153 _options: &SerializeOptions,
154) -> Result<S::Ok, S::Error>
155where
156 S: Serializer,
157{
158 let raw: i32 = msg.transcode_to().map_err(decode_to_ser_err)?;
159
160 serializer.serialize_i32(raw)
161}
162
163fn serialize_int64<S>(
164 msg: &DynamicMessage,
165 serializer: S,
166 options: &SerializeOptions,
167) -> Result<S::Ok, S::Error>
168where
169 S: Serializer,
170{
171 let raw: i64 = msg.transcode_to().map_err(decode_to_ser_err)?;
172
173 if options.stringify_64_bit_integers {
174 serializer.collect_str(&raw)
175 } else {
176 serializer.serialize_i64(raw)
177 }
178}
179
180fn serialize_uint32<S>(
181 msg: &DynamicMessage,
182 serializer: S,
183 _options: &SerializeOptions,
184) -> Result<S::Ok, S::Error>
185where
186 S: Serializer,
187{
188 let raw: u32 = msg.transcode_to().map_err(decode_to_ser_err)?;
189
190 serializer.serialize_u32(raw)
191}
192
193fn serialize_uint64<S>(
194 msg: &DynamicMessage,
195 serializer: S,
196 options: &SerializeOptions,
197) -> Result<S::Ok, S::Error>
198where
199 S: Serializer,
200{
201 let raw: u64 = msg.transcode_to().map_err(decode_to_ser_err)?;
202
203 if options.stringify_64_bit_integers {
204 serializer.collect_str(&raw)
205 } else {
206 serializer.serialize_u64(raw)
207 }
208}
209
210fn serialize_bool<S>(
211 msg: &DynamicMessage,
212 serializer: S,
213 _options: &SerializeOptions,
214) -> Result<S::Ok, S::Error>
215where
216 S: Serializer,
217{
218 let raw: bool = msg.transcode_to().map_err(decode_to_ser_err)?;
219
220 serializer.serialize_bool(raw)
221}
222
223fn serialize_string<S>(
224 msg: &DynamicMessage,
225 serializer: S,
226 _options: &SerializeOptions,
227) -> Result<S::Ok, S::Error>
228where
229 S: Serializer,
230{
231 let raw: String = msg.transcode_to().map_err(decode_to_ser_err)?;
232
233 serializer.serialize_str(&raw)
234}
235
236fn serialize_bytes<S>(
237 msg: &DynamicMessage,
238 serializer: S,
239 _options: &SerializeOptions,
240) -> Result<S::Ok, S::Error>
241where
242 S: Serializer,
243{
244 let raw: Vec<u8> = msg.transcode_to().map_err(decode_to_ser_err)?;
245
246 serializer.collect_str(&Base64Display::new(&raw, &BASE64_STANDARD))
247}
248
249fn serialize_field_mask<S>(
250 msg: &DynamicMessage,
251 serializer: S,
252 _options: &SerializeOptions,
253) -> Result<S::Ok, S::Error>
254where
255 S: Serializer,
256{
257 let raw: prost_types::FieldMask = msg.transcode_to().map_err(decode_to_ser_err)?;
258
259 let mut result = String::new();
260 for path in raw.paths {
261 if !result.is_empty() {
262 result.push(',');
263 }
264
265 let mut first = true;
266 for part in path.split('.') {
267 if !first {
268 result.push('.');
269 }
270 snake_case_to_camel_case(&mut result, part)
271 .map_err(|()| Error::custom("cannot roundtrip field name through camelcase"))?;
272 first = false;
273 }
274 }
275
276 serializer.serialize_str(&result)
277}
278
279fn serialize_empty<S>(
280 _: &DynamicMessage,
281 serializer: S,
282 _options: &SerializeOptions,
283) -> Result<S::Ok, S::Error>
284where
285 S: Serializer,
286{
287 serializer.collect_map(std::iter::empty::<((), ())>())
288}
289
290fn serialize_value<S>(
291 msg: &DynamicMessage,
292 serializer: S,
293 options: &SerializeOptions,
294) -> Result<S::Ok, S::Error>
295where
296 S: Serializer,
297{
298 let raw: prost_types::Value = msg.transcode_to().map_err(decode_to_ser_err)?;
299
300 serialize_value_inner(&raw, serializer, options)
301}
302
303fn serialize_struct<S>(
304 msg: &DynamicMessage,
305 serializer: S,
306 options: &SerializeOptions,
307) -> Result<S::Ok, S::Error>
308where
309 S: Serializer,
310{
311 let raw: prost_types::Struct = msg.transcode_to().map_err(decode_to_ser_err)?;
312
313 serialize_struct_inner(&raw, serializer, options)
314}
315
316fn serialize_list<S>(
317 msg: &DynamicMessage,
318 serializer: S,
319 options: &SerializeOptions,
320) -> Result<S::Ok, S::Error>
321where
322 S: Serializer,
323{
324 let raw: prost_types::ListValue = msg.transcode_to().map_err(decode_to_ser_err)?;
325
326 serialize_list_inner(&raw, serializer, options)
327}
328
329impl Serialize for SerializeWrapper<'_, prost_types::Value> {
330 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
331 where
332 S: Serializer,
333 {
334 serialize_value_inner(self.value, serializer, self.options)
335 }
336}
337
338fn serialize_value_inner<S>(
339 raw: &prost_types::Value,
340 serializer: S,
341 options: &SerializeOptions,
342) -> Result<S::Ok, S::Error>
343where
344 S: Serializer,
345{
346 match &raw.kind {
347 None | Some(prost_types::value::Kind::NullValue(_)) => serializer.serialize_none(),
348 Some(prost_types::value::Kind::BoolValue(value)) => serializer.serialize_bool(*value),
349 Some(prost_types::value::Kind::NumberValue(number)) => {
350 if number.is_finite() {
351 serializer.serialize_f64(*number)
352 } else {
353 Err(Error::custom(
354 "cannot serialize non-finite double in google.protobuf.Value",
355 ))
356 }
357 }
358 Some(prost_types::value::Kind::StringValue(value)) => serializer.serialize_str(value),
359 Some(prost_types::value::Kind::ListValue(value)) => {
360 serialize_list_inner(value, serializer, options)
361 }
362 Some(prost_types::value::Kind::StructValue(value)) => {
363 serialize_struct_inner(value, serializer, options)
364 }
365 }
366}
367
368fn serialize_struct_inner<S>(
369 raw: &prost_types::Struct,
370 serializer: S,
371 options: &SerializeOptions,
372) -> Result<S::Ok, S::Error>
373where
374 S: Serializer,
375{
376 let mut map = serializer.serialize_map(Some(raw.fields.len()))?;
377 for (key, value) in &raw.fields {
378 map.serialize_entry(key, &SerializeWrapper { value, options })?;
379 }
380 map.end()
381}
382
383fn serialize_list_inner<S>(
384 raw: &prost_types::ListValue,
385 serializer: S,
386 options: &SerializeOptions,
387) -> Result<S::Ok, S::Error>
388where
389 S: Serializer,
390{
391 let mut list = serializer.serialize_seq(Some(raw.values.len()))?;
392 for value in &raw.values {
393 list.serialize_element(&SerializeWrapper { value, options })?;
394 }
395 list.end()
396}
397
398fn decode_to_ser_err<E>(err: DecodeError) -> E
399where
400 E: Error,
401{
402 Error::custom(format!("error decoding: {}", err))
403}