prost_reflect/dynamic/serde/ser/wkt.rs

1use base64::{display::Base64Display, prelude::BASE64_STANDARD};
2use prost::{DecodeError, Message};
3use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
4
5use crate::{
6    dynamic::{
7        get_type_url_message_name,
8        serde::{
9            case::snake_case_to_camel_case, check_duration, check_timestamp, is_well_known_type,
10            SerializeOptions,
11        },
12        DynamicMessage,
13    },
14    ReflectMessage,
15};
16
17use super::{serialize_dynamic_message_fields, SerializeWrapper};
18
19#[allow(type_alias_bounds)]
20type WellKnownTypeSerializer<S: Serializer> =
21    fn(&DynamicMessage, S, &SerializeOptions) -> Result<S::Ok, S::Error>;
22
23pub fn get_well_known_type_serializer<S>(full_name: &str) -> Option<WellKnownTypeSerializer<S>>
24where
25    S: Serializer,
26{
27    match full_name {
28        "google.protobuf.Any" => Some(serialize_any),
29        "google.protobuf.Timestamp" => Some(serialize_timestamp),
30        "google.protobuf.Duration" => Some(serialize_duration),
31        "google.protobuf.Struct" => Some(serialize_struct),
32        "google.protobuf.FloatValue" => Some(serialize_float),
33        "google.protobuf.DoubleValue" => Some(serialize_double),
34        "google.protobuf.Int32Value" => Some(serialize_int32),
35        "google.protobuf.Int64Value" => Some(serialize_int64),
36        "google.protobuf.UInt32Value" => Some(serialize_uint32),
37        "google.protobuf.UInt64Value" => Some(serialize_uint64),
38        "google.protobuf.BoolValue" => Some(serialize_bool),
39        "google.protobuf.StringValue" => Some(serialize_string),
40        "google.protobuf.BytesValue" => Some(serialize_bytes),
41        "google.protobuf.FieldMask" => Some(serialize_field_mask),
42        "google.protobuf.ListValue" => Some(serialize_list),
43        "google.protobuf.Value" => Some(serialize_value),
44        "google.protobuf.Empty" => Some(serialize_empty),
45        _ => {
46            debug_assert!(!is_well_known_type(full_name));
47            None
48        }
49    }
50}
51
52fn serialize_any<S>(
53    msg: &DynamicMessage,
54    serializer: S,
55    options: &SerializeOptions,
56) -> Result<S::Ok, S::Error>
57where
58    S: Serializer,
59{
60    let raw: prost_types::Any = msg.transcode_to().map_err(decode_to_ser_err)?;
61
62    let message_name = get_type_url_message_name(&raw.type_url).map_err(Error::custom)?;
63
64    let message_desc = msg
65        .descriptor()
66        .parent_pool()
67        .get_message_by_name(message_name)
68        .ok_or_else(|| Error::custom(format!("message '{}' not found", message_name)))?;
69
70    let mut payload_message = DynamicMessage::new(message_desc);
71    payload_message
72        .merge(raw.value.as_ref())
73        .map_err(decode_to_ser_err)?;
74
75    if is_well_known_type(message_name) {
76        let mut map = serializer.serialize_map(Some(2))?;
77        map.serialize_entry("@type", &raw.type_url)?;
78        map.serialize_entry(
79            "value",
80            &SerializeWrapper {
81                value: &payload_message,
82                options,
83            },
84        )?;
85        map.end()
86    } else {
87        let mut map = serializer.serialize_map(None)?;
88        map.serialize_entry("@type", &raw.type_url)?;
89        serialize_dynamic_message_fields(&mut map, &payload_message, options)?;
90        map.end()
91    }
92}
93
94fn serialize_timestamp<S>(
95    msg: &DynamicMessage,
96    serializer: S,
97    _options: &SerializeOptions,
98) -> Result<S::Ok, S::Error>
99where
100    S: Serializer,
101{
102    let timestamp: prost_types::Timestamp = msg.transcode_to().map_err(decode_to_ser_err)?;
103
104    check_timestamp(&timestamp).map_err(Error::custom)?;
105
106    serializer.collect_str(&timestamp)
107}
108
109fn serialize_duration<S>(
110    msg: &DynamicMessage,
111    serializer: S,
112    _options: &SerializeOptions,
113) -> Result<S::Ok, S::Error>
114where
115    S: Serializer,
116{
117    let duration: prost_types::Duration = msg.transcode_to().map_err(decode_to_ser_err)?;
118
119    check_duration(&duration).map_err(Error::custom)?;
120
121    serializer.collect_str(&duration)
122}
123
124fn serialize_float<S>(
125    msg: &DynamicMessage,
126    serializer: S,
127    _options: &SerializeOptions,
128) -> Result<S::Ok, S::Error>
129where
130    S: Serializer,
131{
132    let raw: f32 = msg.transcode_to().map_err(decode_to_ser_err)?;
133
134    serializer.serialize_f32(raw)
135}
136
137fn serialize_double<S>(
138    msg: &DynamicMessage,
139    serializer: S,
140    _options: &SerializeOptions,
141) -> Result<S::Ok, S::Error>
142where
143    S: Serializer,
144{
145    let raw: f64 = msg.transcode_to().map_err(decode_to_ser_err)?;
146
147    serializer.serialize_f64(raw)
148}
149
150fn serialize_int32<S>(
151    msg: &DynamicMessage,
152    serializer: S,
153    _options: &SerializeOptions,
154) -> Result<S::Ok, S::Error>
155where
156    S: Serializer,
157{
158    let raw: i32 = msg.transcode_to().map_err(decode_to_ser_err)?;
159
160    serializer.serialize_i32(raw)
161}
162
163fn serialize_int64<S>(
164    msg: &DynamicMessage,
165    serializer: S,
166    options: &SerializeOptions,
167) -> Result<S::Ok, S::Error>
168where
169    S: Serializer,
170{
171    let raw: i64 = msg.transcode_to().map_err(decode_to_ser_err)?;
172
173    if options.stringify_64_bit_integers {
174        serializer.collect_str(&raw)
175    } else {
176        serializer.serialize_i64(raw)
177    }
178}
179
180fn serialize_uint32<S>(
181    msg: &DynamicMessage,
182    serializer: S,
183    _options: &SerializeOptions,
184) -> Result<S::Ok, S::Error>
185where
186    S: Serializer,
187{
188    let raw: u32 = msg.transcode_to().map_err(decode_to_ser_err)?;
189
190    serializer.serialize_u32(raw)
191}
192
193fn serialize_uint64<S>(
194    msg: &DynamicMessage,
195    serializer: S,
196    options: &SerializeOptions,
197) -> Result<S::Ok, S::Error>
198where
199    S: Serializer,
200{
201    let raw: u64 = msg.transcode_to().map_err(decode_to_ser_err)?;
202
203    if options.stringify_64_bit_integers {
204        serializer.collect_str(&raw)
205    } else {
206        serializer.serialize_u64(raw)
207    }
208}
209
210fn serialize_bool<S>(
211    msg: &DynamicMessage,
212    serializer: S,
213    _options: &SerializeOptions,
214) -> Result<S::Ok, S::Error>
215where
216    S: Serializer,
217{
218    let raw: bool = msg.transcode_to().map_err(decode_to_ser_err)?;
219
220    serializer.serialize_bool(raw)
221}
222
223fn serialize_string<S>(
224    msg: &DynamicMessage,
225    serializer: S,
226    _options: &SerializeOptions,
227) -> Result<S::Ok, S::Error>
228where
229    S: Serializer,
230{
231    let raw: String = msg.transcode_to().map_err(decode_to_ser_err)?;
232
233    serializer.serialize_str(&raw)
234}
235
236fn serialize_bytes<S>(
237    msg: &DynamicMessage,
238    serializer: S,
239    _options: &SerializeOptions,
240) -> Result<S::Ok, S::Error>
241where
242    S: Serializer,
243{
244    let raw: Vec<u8> = msg.transcode_to().map_err(decode_to_ser_err)?;
245
246    serializer.collect_str(&Base64Display::new(&raw, &BASE64_STANDARD))
247}
248
249fn serialize_field_mask<S>(
250    msg: &DynamicMessage,
251    serializer: S,
252    _options: &SerializeOptions,
253) -> Result<S::Ok, S::Error>
254where
255    S: Serializer,
256{
257    let raw: prost_types::FieldMask = msg.transcode_to().map_err(decode_to_ser_err)?;
258
259    let mut result = String::new();
260    for path in raw.paths {
261        if !result.is_empty() {
262            result.push(',');
263        }
264
265        let mut first = true;
266        for part in path.split('.') {
267            if !first {
268                result.push('.');
269            }
270            snake_case_to_camel_case(&mut result, part)
271                .map_err(|()| Error::custom("cannot roundtrip field name through camelcase"))?;
272            first = false;
273        }
274    }
275
276    serializer.serialize_str(&result)
277}
278
279fn serialize_empty<S>(
280    _: &DynamicMessage,
281    serializer: S,
282    _options: &SerializeOptions,
283) -> Result<S::Ok, S::Error>
284where
285    S: Serializer,
286{
287    serializer.collect_map(std::iter::empty::<((), ())>())
288}
289
290fn serialize_value<S>(
291    msg: &DynamicMessage,
292    serializer: S,
293    options: &SerializeOptions,
294) -> Result<S::Ok, S::Error>
295where
296    S: Serializer,
297{
298    let raw: prost_types::Value = msg.transcode_to().map_err(decode_to_ser_err)?;
299
300    serialize_value_inner(&raw, serializer, options)
301}
302
303fn serialize_struct<S>(
304    msg: &DynamicMessage,
305    serializer: S,
306    options: &SerializeOptions,
307) -> Result<S::Ok, S::Error>
308where
309    S: Serializer,
310{
311    let raw: prost_types::Struct = msg.transcode_to().map_err(decode_to_ser_err)?;
312
313    serialize_struct_inner(&raw, serializer, options)
314}
315
316fn serialize_list<S>(
317    msg: &DynamicMessage,
318    serializer: S,
319    options: &SerializeOptions,
320) -> Result<S::Ok, S::Error>
321where
322    S: Serializer,
323{
324    let raw: prost_types::ListValue = msg.transcode_to().map_err(decode_to_ser_err)?;
325
326    serialize_list_inner(&raw, serializer, options)
327}
328
329impl Serialize for SerializeWrapper<'_, prost_types::Value> {
330    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
331    where
332        S: Serializer,
333    {
334        serialize_value_inner(self.value, serializer, self.options)
335    }
336}
337
338fn serialize_value_inner<S>(
339    raw: &prost_types::Value,
340    serializer: S,
341    options: &SerializeOptions,
342) -> Result<S::Ok, S::Error>
343where
344    S: Serializer,
345{
346    match &raw.kind {
347        None | Some(prost_types::value::Kind::NullValue(_)) => serializer.serialize_none(),
348        Some(prost_types::value::Kind::BoolValue(value)) => serializer.serialize_bool(*value),
349        Some(prost_types::value::Kind::NumberValue(number)) => {
350            if number.is_finite() {
351                serializer.serialize_f64(*number)
352            } else {
353                Err(Error::custom(
354                    "cannot serialize non-finite double in google.protobuf.Value",
355                ))
356            }
357        }
358        Some(prost_types::value::Kind::StringValue(value)) => serializer.serialize_str(value),
359        Some(prost_types::value::Kind::ListValue(value)) => {
360            serialize_list_inner(value, serializer, options)
361        }
362        Some(prost_types::value::Kind::StructValue(value)) => {
363            serialize_struct_inner(value, serializer, options)
364        }
365    }
366}
367
368fn serialize_struct_inner<S>(
369    raw: &prost_types::Struct,
370    serializer: S,
371    options: &SerializeOptions,
372) -> Result<S::Ok, S::Error>
373where
374    S: Serializer,
375{
376    let mut map = serializer.serialize_map(Some(raw.fields.len()))?;
377    for (key, value) in &raw.fields {
378        map.serialize_entry(key, &SerializeWrapper { value, options })?;
379    }
380    map.end()
381}
382
383fn serialize_list_inner<S>(
384    raw: &prost_types::ListValue,
385    serializer: S,
386    options: &SerializeOptions,
387) -> Result<S::Ok, S::Error>
388where
389    S: Serializer,
390{
391    let mut list = serializer.serialize_seq(Some(raw.values.len()))?;
392    for value in &raw.values {
393        list.serialize_element(&SerializeWrapper { value, options })?;
394    }
395    list.end()
396}
397
398fn decode_to_ser_err<E>(err: DecodeError) -> E
399where
400    E: Error,
401{
402    Error::custom(format!("error decoding: {}", err))
403}