prost_reflect/dynamic/fields.rs

1use std::{
2    borrow::Cow,
3    collections::btree_map::{self, BTreeMap},
4    fmt,
5    mem::replace,
6};
7
8use crate::{
9    ExtensionDescriptor, FieldDescriptor, Kind, MessageDescriptor, OneofDescriptor, Value,
10};
11
12use super::{
13    unknown::{UnknownField, UnknownFieldSet},
14    Either,
15};
16
17pub(crate) trait FieldDescriptorLike: fmt::Debug {
18    #[cfg(feature = "text-format")]
19    fn text_name(&self) -> &str;
20    fn number(&self) -> u32;
21    fn default_value(&self) -> Value;
22    fn is_default_value(&self, value: &Value) -> bool;
23    fn is_valid(&self, value: &Value) -> bool;
24    fn containing_oneof(&self) -> Option<OneofDescriptor>;
25    fn supports_presence(&self) -> bool;
26    fn kind(&self) -> Kind;
27    fn is_group(&self) -> bool;
28    fn is_list(&self) -> bool;
29    fn is_map(&self) -> bool;
30    fn is_packed(&self) -> bool;
31    fn is_packable(&self) -> bool;
32    fn has(&self, value: &Value) -> bool {
33        self.supports_presence() || !self.is_default_value(value)
34    }
35}
36
37/// A set of extension fields in a protobuf message.
38#[derive(Default, Debug, Clone, PartialEq)]
39pub(super) struct DynamicMessageFieldSet {
40    fields: BTreeMap<u32, ValueOrUnknown>,
41}
42
43#[derive(Debug, Clone, PartialEq)]
44pub(super) enum ValueOrUnknown {
45    /// Used to implement draining iterators.
46    Taken,
47    /// A protobuf value with known field type.
48    Value(Value),
49    /// One or more unknown fields.
50    Unknown(UnknownFieldSet),
51}
52
53pub(super) enum ValueAndDescriptor<'a> {
54    Field(Cow<'a, Value>, FieldDescriptor),
55    Extension(Cow<'a, Value>, ExtensionDescriptor),
56    Unknown(&'a UnknownFieldSet),
57}
58
59impl DynamicMessageFieldSet {
60    fn get_value(&self, number: u32) -> Option<&Value> {
61        match self.fields.get(&number) {
62            Some(ValueOrUnknown::Value(value)) => Some(value),
63            Some(ValueOrUnknown::Unknown(_) | ValueOrUnknown::Taken) | None => None,
64        }
65    }
66
67    pub(super) fn has(&self, desc: &impl FieldDescriptorLike) -> bool {
68        self.get_value(desc.number())
69            .map(|value| desc.has(value))
70            .unwrap_or(false)
71    }
72
73    pub(super) fn get(&self, desc: &impl FieldDescriptorLike) -> Cow<'_, Value> {
74        match self.get_value(desc.number()) {
75            Some(value) => Cow::Borrowed(value),
76            None => Cow::Owned(desc.default_value()),
77        }
78    }
79
80    pub(super) fn get_mut(&mut self, desc: &impl FieldDescriptorLike) -> &mut Value {
81        self.clear_oneof_fields(desc);
82        match self.fields.entry(desc.number()) {
83            btree_map::Entry::Occupied(entry) => match entry.into_mut() {
84                ValueOrUnknown::Value(value) => value,
85                value => {
86                    *value = ValueOrUnknown::Value(desc.default_value());
87                    value.unwrap_value_mut()
88                }
89            },
90            btree_map::Entry::Vacant(entry) => entry
91                .insert(ValueOrUnknown::Value(desc.default_value()))
92                .unwrap_value_mut(),
93        }
94    }
95
96    pub(super) fn set(&mut self, desc: &impl FieldDescriptorLike, value: Value) {
97        debug_assert!(
98            desc.is_valid(&value),
99            "invalid value {:?} for field {:?}",
100            value,
101            desc,
102        );
103
104        self.clear_oneof_fields(desc);
105        self.fields
106            .insert(desc.number(), ValueOrUnknown::Value(value));
107    }
108
109    fn clear_oneof_fields(&mut self, desc: &impl FieldDescriptorLike) {
110        if let Some(oneof_desc) = desc.containing_oneof() {
111            for oneof_field in oneof_desc.fields() {
112                if oneof_field.number() != desc.number() {
113                    self.clear(&oneof_field);
114                }
115            }
116        }
117    }
118
119    pub(crate) fn add_unknown(&mut self, number: u32, unknown: UnknownField) {
120        match self.fields.entry(number) {
121            btree_map::Entry::Occupied(mut entry) => match entry.get_mut() {
122                ValueOrUnknown::Value(_) => {
123                    panic!("expected no field to be found with number {}", number)
124                }
125                value @ ValueOrUnknown::Taken => {
126                    *value = ValueOrUnknown::Unknown(UnknownFieldSet::from_iter([unknown]))
127                }
128                ValueOrUnknown::Unknown(unknowns) => unknowns.insert(unknown),
129            },
130            btree_map::Entry::Vacant(entry) => {
131                entry.insert(ValueOrUnknown::Unknown(UnknownFieldSet::from_iter([
132                    unknown,
133                ])));
134            }
135        }
136    }
137
138    pub(super) fn clear(&mut self, desc: &impl FieldDescriptorLike) {
139        self.fields.remove(&desc.number());
140    }
141
142    pub(crate) fn take(&mut self, desc: &impl FieldDescriptorLike) -> Option<Value> {
143        match self.fields.remove(&desc.number()) {
144            Some(ValueOrUnknown::Value(value)) if desc.has(&value) => Some(value),
145            _ => None,
146        }
147    }
148
149    /// Iterates over the fields in the message.
150    ///
151    /// If `include_default` is `true`, fields with their default value will be included.
152    /// If `index_order` is `true`, fields will be iterated in the order they were defined in the source code. Otherwise, they will be iterated in field number order.
153    pub(crate) fn iter<'a>(
154        &'a self,
155        message: &'a MessageDescriptor,
156        include_default: bool,
157        index_order: bool,
158    ) -> impl Iterator<Item = ValueAndDescriptor<'a>> + 'a {
159        let field_descriptors = if index_order {
160            Either::Left(message.fields_in_index_order())
161        } else {
162            Either::Right(message.fields())
163        };
164
165        let fields = field_descriptors
166            .filter(move |f| {
167                if include_default {
168                    !f.supports_presence() || self.has(f)
169                } else {
170                    self.has(f)
171                }
172            })
173            .map(|f| ValueAndDescriptor::Field(self.get(&f), f));
174
175        let extensions_unknowns =
176            self.fields
177                .iter()
178                .filter_map(move |(&number, value)| match value {
179                    ValueOrUnknown::Value(value) => {
180                        if let Some(extension) = message.get_extension(number) {
181                            if extension.has(value) {
182                                Some(ValueAndDescriptor::Extension(
183                                    Cow::Borrowed(value),
184                                    extension,
185                                ))
186                            } else {
187                                None
188                            }
189                        } else {
190                            None
191                        }
192                    }
193                    ValueOrUnknown::Unknown(unknown) => Some(ValueAndDescriptor::Unknown(unknown)),
194                    ValueOrUnknown::Taken => None,
195                });
196
197        fields.chain(extensions_unknowns)
198    }
199
200    pub(crate) fn iter_fields<'a>(
201        &'a self,
202        message: &'a MessageDescriptor,
203    ) -> impl Iterator<Item = (FieldDescriptor, &'a Value)> + 'a {
204        self.fields.iter().filter_map(move |(&number, value)| {
205            let value = match value {
206                ValueOrUnknown::Value(value) => value,
207                _ => return None,
208            };
209            let field = match message.get_field(number) {
210                Some(field) => field,
211                _ => return None,
212            };
213            if field.has(value) {
214                Some((field, value))
215            } else {
216                None
217            }
218        })
219    }
220
221    pub(crate) fn iter_extensions<'a>(
222        &'a self,
223        message: &'a MessageDescriptor,
224    ) -> impl Iterator<Item = (ExtensionDescriptor, &'a Value)> + 'a {
225        self.fields.iter().filter_map(move |(&number, value)| {
226            let value = match value {
227                ValueOrUnknown::Value(value) => value,
228                _ => return None,
229            };
230            let field = match message.get_extension(number) {
231                Some(field) => field,
232                _ => return None,
233            };
234            if field.has(value) {
235                Some((field, value))
236            } else {
237                None
238            }
239        })
240    }
241
242    pub(super) fn iter_unknown(&self) -> impl Iterator<Item = &'_ UnknownField> {
243        self.fields.values().flat_map(move |value| match value {
244            ValueOrUnknown::Taken | ValueOrUnknown::Value(_) => [].iter(),
245            ValueOrUnknown::Unknown(unknowns) => unknowns.iter(),
246        })
247    }
248
249    pub(crate) fn iter_fields_mut<'a>(
250        &'a mut self,
251        message: &'a MessageDescriptor,
252    ) -> impl Iterator<Item = (FieldDescriptor, &'a mut Value)> + 'a {
253        self.fields.iter_mut().filter_map(move |(&number, value)| {
254            let value = match value {
255                ValueOrUnknown::Value(value) => value,
256                _ => return None,
257            };
258            let field = match message.get_field(number) {
259                Some(field) => field,
260                _ => return None,
261            };
262            if field.has(value) {
263                Some((field, value))
264            } else {
265                None
266            }
267        })
268    }
269
270    pub(crate) fn iter_extensions_mut<'a>(
271        &'a mut self,
272        message: &'a MessageDescriptor,
273    ) -> impl Iterator<Item = (ExtensionDescriptor, &'a mut Value)> + 'a {
274        self.fields.iter_mut().filter_map(move |(&number, value)| {
275            let value = match value {
276                ValueOrUnknown::Value(value) => value,
277                _ => return None,
278            };
279            let field = match message.get_extension(number) {
280                Some(field) => field,
281                _ => return None,
282            };
283            if field.has(value) {
284                Some((field, value))
285            } else {
286                None
287            }
288        })
289    }
290
291    pub(crate) fn take_fields<'a>(
292        &'a mut self,
293        message: &'a MessageDescriptor,
294    ) -> impl Iterator<Item = (FieldDescriptor, Value)> + 'a {
295        self.fields
296            .iter_mut()
297            .filter_map(move |(&number, value_or_unknown)| {
298                let value = match value_or_unknown {
299                    ValueOrUnknown::Value(value) => value,
300                    _ => return None,
301                };
302                let field = match message.get_field(number) {
303                    Some(field) => field,
304                    _ => return None,
305                };
306                if field.has(value) {
307                    Some((
308                        field,
309                        replace(value_or_unknown, ValueOrUnknown::Taken).unwrap_value(),
310                    ))
311                } else {
312                    None
313                }
314            })
315    }
316
317    pub(crate) fn take_extensions<'a>(
318        &'a mut self,
319        message: &'a MessageDescriptor,
320    ) -> impl Iterator<Item = (ExtensionDescriptor, Value)> + 'a {
321        self.fields
322            .iter_mut()
323            .filter_map(move |(&number, value_or_unknown)| {
324                let value = match value_or_unknown {
325                    ValueOrUnknown::Value(value) => value,
326                    _ => return None,
327                };
328                let field = match message.get_extension(number) {
329                    Some(field) => field,
330                    _ => return None,
331                };
332                if field.has(value) {
333                    Some((
334                        field,
335                        replace(value_or_unknown, ValueOrUnknown::Taken).unwrap_value(),
336                    ))
337                } else {
338                    None
339                }
340            })
341    }
342
343    pub(crate) fn take_unknown(&mut self) -> impl Iterator<Item = UnknownField> + '_ {
344        self.fields
345            .values_mut()
346            .flat_map(move |value_or_unknown| match value_or_unknown {
347                ValueOrUnknown::Unknown(_) => replace(value_or_unknown, ValueOrUnknown::Taken)
348                    .unwrap_unknown()
349                    .into_iter(),
350                _ => vec![].into_iter(),
351            })
352    }
353
354    pub(super) fn clear_all(&mut self) {
355        self.fields.clear();
356    }
357}
358
359impl ValueOrUnknown {
360    fn unwrap_value_mut(&mut self) -> &mut Value {
361        match self {
362            ValueOrUnknown::Value(value) => value,
363            ValueOrUnknown::Unknown(_) | ValueOrUnknown::Taken => unreachable!(),
364        }
365    }
366
367    fn unwrap_value(self) -> Value {
368        match self {
369            ValueOrUnknown::Value(value) => value,
370            ValueOrUnknown::Unknown(_) | ValueOrUnknown::Taken => unreachable!(),
371        }
372    }
373
374    fn unwrap_unknown(self) -> UnknownFieldSet {
375        match self {
376            ValueOrUnknown::Unknown(unknowns) => unknowns,
377            ValueOrUnknown::Value(_) | ValueOrUnknown::Taken => unreachable!(),
378        }
379    }
380}
381
382impl FieldDescriptorLike for FieldDescriptor {
383    #[cfg(feature = "text-format")]
384    fn text_name(&self) -> &str {
385        self.name()
386    }
387
388    fn number(&self) -> u32 {
389        self.number()
390    }
391
392    fn default_value(&self) -> Value {
393        Value::default_value_for_field(self)
394    }
395
396    fn is_default_value(&self, value: &Value) -> bool {
397        value.is_default_for_field(self)
398    }
399
400    fn is_valid(&self, value: &Value) -> bool {
401        value.is_valid_for_field(self)
402    }
403
404    fn containing_oneof(&self) -> Option<OneofDescriptor> {
405        self.containing_oneof()
406    }
407
408    fn supports_presence(&self) -> bool {
409        self.supports_presence()
410    }
411
412    fn kind(&self) -> Kind {
413        self.kind()
414    }
415
416    fn is_group(&self) -> bool {
417        self.is_group()
418    }
419
420    fn is_list(&self) -> bool {
421        self.is_list()
422    }
423
424    fn is_map(&self) -> bool {
425        self.is_map()
426    }
427
428    fn is_packed(&self) -> bool {
429        self.is_packed()
430    }
431
432    fn is_packable(&self) -> bool {
433        self.is_packable()
434    }
435}
436
437impl FieldDescriptorLike for ExtensionDescriptor {
438    #[cfg(feature = "text-format")]
439    fn text_name(&self) -> &str {
440        self.json_name()
441    }
442
443    fn number(&self) -> u32 {
444        self.number()
445    }
446
447    fn default_value(&self) -> Value {
448        Value::default_value_for_extension(self)
449    }
450
451    fn is_default_value(&self, value: &Value) -> bool {
452        value.is_default_for_extension(self)
453    }
454
455    fn is_valid(&self, value: &Value) -> bool {
456        value.is_valid_for_extension(self)
457    }
458
459    fn containing_oneof(&self) -> Option<OneofDescriptor> {
460        None
461    }
462
463    fn supports_presence(&self) -> bool {
464        self.supports_presence()
465    }
466
467    fn kind(&self) -> Kind {
468        self.kind()
469    }
470
471    fn is_group(&self) -> bool {
472        self.is_group()
473    }
474
475    fn is_list(&self) -> bool {
476        self.is_list()
477    }
478
479    fn is_map(&self) -> bool {
480        self.is_map()
481    }
482
483    fn is_packed(&self) -> bool {
484        self.is_packed()
485    }
486
487    fn is_packable(&self) -> bool {
488        self.is_packable()
489    }
490}