// prost_derive/field/scalar.rs

1use std::fmt;
2
3use anyhow::{anyhow, bail, Error};
4use proc_macro2::{Span, TokenStream};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path};
7
8use crate::field::{bool_attr, set_option, tag_attr, Label};
9
10/// A scalar protobuf field.
11#[derive(Clone)]
12pub struct Field {
13    pub ty: Ty,
14    pub kind: Kind,
15    pub tag: u32,
16}
17
18impl Field {
19    pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
20        let mut ty = None;
21        let mut label = None;
22        let mut packed = None;
23        let mut default = None;
24        let mut tag = None;
25
26        let mut unknown_attrs = Vec::new();
27
28        for attr in attrs {
29            if let Some(t) = Ty::from_attr(attr)? {
30                set_option(&mut ty, t, "duplicate type attributes")?;
31            } else if let Some(p) = bool_attr("packed", attr)? {
32                set_option(&mut packed, p, "duplicate packed attributes")?;
33            } else if let Some(t) = tag_attr(attr)? {
34                set_option(&mut tag, t, "duplicate tag attributes")?;
35            } else if let Some(l) = Label::from_attr(attr) {
36                set_option(&mut label, l, "duplicate label attributes")?;
37            } else if let Some(d) = DefaultValue::from_attr(attr)? {
38                set_option(&mut default, d, "duplicate default attributes")?;
39            } else {
40                unknown_attrs.push(attr);
41            }
42        }
43
44        let ty = match ty {
45            Some(ty) => ty,
46            None => return Ok(None),
47        };
48
49        if !unknown_attrs.is_empty() {
50            bail!(
51                "unknown attribute(s): #[prost({})]",
52                quote!(#(#unknown_attrs),*)
53            );
54        }
55
56        let tag = match tag.or(inferred_tag) {
57            Some(tag) => tag,
58            None => bail!("missing tag attribute"),
59        };
60
61        let has_default = default.is_some();
62        let default = default.map_or_else(
63            || Ok(DefaultValue::new(&ty)),
64            |lit| DefaultValue::from_lit(&ty, lit),
65        )?;
66
67        let kind = match (label, packed, has_default) {
68            (None, Some(true), _)
69            | (Some(Label::Optional), Some(true), _)
70            | (Some(Label::Required), Some(true), _) => {
71                bail!("packed attribute may only be applied to repeated fields");
72            }
73            (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => {
74                bail!("packed attribute may only be applied to numeric types");
75            }
76            (Some(Label::Repeated), _, true) => {
77                bail!("repeated fields may not have a default value");
78            }
79
80            (None, _, _) => Kind::Plain(default),
81            (Some(Label::Optional), _, _) => Kind::Optional(default),
82            (Some(Label::Required), _, _) => Kind::Required(default),
83            (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => {
84                Kind::Packed
85            }
86            (Some(Label::Repeated), _, false) => Kind::Repeated,
87        };
88
89        Ok(Some(Field { ty, kind, tag }))
90    }
91
92    pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
93        if let Some(mut field) = Field::new(attrs, None)? {
94            match field.kind {
95                Kind::Plain(default) => {
96                    field.kind = Kind::Required(default);
97                    Ok(Some(field))
98                }
99                Kind::Optional(..) => bail!("invalid optional attribute on oneof field"),
100                Kind::Required(..) => bail!("invalid required attribute on oneof field"),
101                Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"),
102            }
103        } else {
104            Ok(None)
105        }
106    }
107
108    pub fn encode(&self, ident: TokenStream) -> TokenStream {
109        let module = self.ty.module();
110        let encode_fn = match self.kind {
111            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode),
112            Kind::Repeated => quote!(encode_repeated),
113            Kind::Packed => quote!(encode_packed),
114        };
115        let encode_fn = quote!(::prost::encoding::#module::#encode_fn);
116        let tag = self.tag;
117
118        match self.kind {
119            Kind::Plain(ref default) => {
120                let default = default.typed();
121                quote! {
122                    if #ident != #default {
123                        #encode_fn(#tag, &#ident, buf);
124                    }
125                }
126            }
127            Kind::Optional(..) => quote! {
128                if let ::core::option::Option::Some(ref value) = #ident {
129                    #encode_fn(#tag, value, buf);
130                }
131            },
132            Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
133                #encode_fn(#tag, &#ident, buf);
134            },
135        }
136    }
137
138    /// Returns an expression which evaluates to the result of merging a decoded
139    /// scalar value into the field.
140    pub fn merge(&self, ident: TokenStream) -> TokenStream {
141        let module = self.ty.module();
142        let merge_fn = match self.kind {
143            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge),
144            Kind::Repeated | Kind::Packed => quote!(merge_repeated),
145        };
146        let merge_fn = quote!(::prost::encoding::#module::#merge_fn);
147
148        match self.kind {
149            Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
150                #merge_fn(wire_type, #ident, buf, ctx)
151            },
152            Kind::Optional(..) => quote! {
153                #merge_fn(wire_type,
154                          #ident.get_or_insert_with(::core::default::Default::default),
155                          buf,
156                          ctx)
157            },
158        }
159    }
160
161    /// Returns an expression which evaluates to the encoded length of the field.
162    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
163        let module = self.ty.module();
164        let encoded_len_fn = match self.kind {
165            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len),
166            Kind::Repeated => quote!(encoded_len_repeated),
167            Kind::Packed => quote!(encoded_len_packed),
168        };
169        let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn);
170        let tag = self.tag;
171
172        match self.kind {
173            Kind::Plain(ref default) => {
174                let default = default.typed();
175                quote! {
176                    if #ident != #default {
177                        #encoded_len_fn(#tag, &#ident)
178                    } else {
179                        0
180                    }
181                }
182            }
183            Kind::Optional(..) => quote! {
184                #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value))
185            },
186            Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
187                #encoded_len_fn(#tag, &#ident)
188            },
189        }
190    }
191
192    pub fn clear(&self, ident: TokenStream) -> TokenStream {
193        match self.kind {
194            Kind::Plain(ref default) | Kind::Required(ref default) => {
195                let default = default.typed();
196                match self.ty {
197                    Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
198                    _ => quote!(#ident = #default),
199                }
200            }
201            Kind::Optional(_) => quote!(#ident = ::core::option::Option::None),
202            Kind::Repeated | Kind::Packed => quote!(#ident.clear()),
203        }
204    }
205
206    /// Returns an expression which evaluates to the default value of the field.
207    pub fn default(&self) -> TokenStream {
208        match self.kind {
209            Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(),
210            Kind::Optional(_) => quote!(::core::option::Option::None),
211            Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()),
212        }
213    }
214
215    /// An inner debug wrapper, around the base type.
216    fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream {
217        if let Ty::Enumeration(ref ty) = self.ty {
218            quote! {
219                struct #wrap_name<'a>(&'a i32);
220                impl<'a> ::core::fmt::Debug for #wrap_name<'a> {
221                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
222                        let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0);
223                        match res {
224                            Err(_) => ::core::fmt::Debug::fmt(&self.0, f),
225                            Ok(en) => ::core::fmt::Debug::fmt(&en, f),
226                        }
227                    }
228                }
229            }
230        } else {
231            quote! {
232                #[allow(non_snake_case)]
233                fn #wrap_name<T>(v: T) -> T { v }
234            }
235        }
236    }
237
238    /// Returns a fragment for formatting the field `ident` in `Debug`.
239    pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
240        let wrapper = self.debug_inner(quote!(Inner));
241        let inner_ty = self.ty.rust_type();
242        match self.kind {
243            Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name),
244            Kind::Optional(_) => quote! {
245                struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>);
246                impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
247                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
248                        #wrapper
249                        ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f)
250                    }
251                }
252            },
253            Kind::Repeated | Kind::Packed => {
254                quote! {
255                    struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>);
256                    impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
257                        fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
258                            let mut vec_builder = f.debug_list();
259                            for v in self.0 {
260                                #wrapper
261                                vec_builder.entry(&Inner(v));
262                            }
263                            vec_builder.finish()
264                        }
265                    }
266                }
267            }
268        }
269    }
270
271    /// Returns methods to embed in the message.
272    pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> {
273        let mut ident_str = ident.to_string();
274        if ident_str.starts_with("r#") {
275            ident_str = ident_str.split_off(2);
276        }
277
278        // Prepend `get_` for getter methods of tuple structs.
279        let get = match syn::parse_str::<Index>(&ident_str) {
280            Ok(index) => {
281                let get = Ident::new(&format!("get_{}", index.index), Span::call_site());
282                quote!(#get)
283            }
284            Err(_) => quote!(#ident),
285        };
286
287        if let Ty::Enumeration(ref ty) = self.ty {
288            let set = Ident::new(&format!("set_{}", ident_str), Span::call_site());
289            let set_doc = format!("Sets `{}` to the provided enum value.", ident_str);
290            Some(match self.kind {
291                Kind::Plain(ref default) | Kind::Required(ref default) => {
292                    let get_doc = format!(
293                        "Returns the enum value of `{}`, \
294                         or the default if the field is set to an invalid enum value.",
295                        ident_str,
296                    );
297                    quote! {
298                        #[doc=#get_doc]
299                        pub fn #get(&self) -> #ty {
300                            ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default)
301                        }
302
303                        #[doc=#set_doc]
304                        pub fn #set(&mut self, value: #ty) {
305                            self.#ident = value as i32;
306                        }
307                    }
308                }
309                Kind::Optional(ref default) => {
310                    let get_doc = format!(
311                        "Returns the enum value of `{}`, \
312                         or the default if the field is unset or set to an invalid enum value.",
313                        ident_str,
314                    );
315                    quote! {
316                        #[doc=#get_doc]
317                        pub fn #get(&self) -> #ty {
318                            self.#ident.and_then(|x| {
319                                let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
320                                result.ok()
321                            }).unwrap_or(#default)
322                        }
323
324                        #[doc=#set_doc]
325                        pub fn #set(&mut self, value: #ty) {
326                            self.#ident = ::core::option::Option::Some(value as i32);
327                        }
328                    }
329                }
330                Kind::Repeated | Kind::Packed => {
331                    let iter_doc = format!(
332                        "Returns an iterator which yields the valid enum values contained in `{}`.",
333                        ident_str,
334                    );
335                    let push = Ident::new(&format!("push_{}", ident_str), Span::call_site());
336                    let push_doc = format!("Appends the provided enum value to `{}`.", ident_str);
337                    quote! {
338                        #[doc=#iter_doc]
339                        pub fn #get(&self) -> ::core::iter::FilterMap<
340                            ::core::iter::Cloned<::core::slice::Iter<i32>>,
341                            fn(i32) -> ::core::option::Option<#ty>,
342                        > {
343                            self.#ident.iter().cloned().filter_map(|x| {
344                                let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
345                                result.ok()
346                            })
347                        }
348                        #[doc=#push_doc]
349                        pub fn #push(&mut self, value: #ty) {
350                            self.#ident.push(value as i32);
351                        }
352                    }
353                }
354            })
355        } else if let Kind::Optional(ref default) = self.kind {
356            let ty = self.ty.rust_ref_type();
357
358            let match_some = if self.ty.is_numeric() {
359                quote!(::core::option::Option::Some(val) => val,)
360            } else {
361                quote!(::core::option::Option::Some(ref val) => &val[..],)
362            };
363
364            let get_doc = format!(
365                "Returns the value of `{0}`, or the default value if `{0}` is unset.",
366                ident_str,
367            );
368
369            Some(quote! {
370                #[doc=#get_doc]
371                pub fn #get(&self) -> #ty {
372                    match self.#ident {
373                        #match_some
374                        ::core::option::Option::None => #default,
375                    }
376                }
377            })
378        } else {
379            None
380        }
381    }
382}
383
384/// A scalar protobuf field type.
385#[derive(Clone, PartialEq, Eq)]
386pub enum Ty {
387    Double,
388    Float,
389    Int32,
390    Int64,
391    Uint32,
392    Uint64,
393    Sint32,
394    Sint64,
395    Fixed32,
396    Fixed64,
397    Sfixed32,
398    Sfixed64,
399    Bool,
400    String,
401    Bytes(BytesTy),
402    Enumeration(Path),
403}
404
/// The Rust container used for a `bytes` field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BytesTy {
    /// `Vec<u8>` — selected by a bare `bytes` attribute or by `bytes = "vec"`.
    Vec,
    /// `bytes::Bytes` — selected by `bytes = "bytes"`.
    Bytes,
}
410
411impl BytesTy {
412    fn try_from_str(s: &str) -> Result<Self, Error> {
413        match s {
414            "vec" => Ok(BytesTy::Vec),
415            "bytes" => Ok(BytesTy::Bytes),
416            _ => bail!("Invalid bytes type: {}", s),
417        }
418    }
419
420    fn rust_type(&self) -> TokenStream {
421        match self {
422            BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> },
423            BytesTy::Bytes => quote! { ::prost::bytes::Bytes },
424        }
425    }
426}
427
428impl Ty {
429    pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
430        let ty = match *attr {
431            Meta::Path(ref name) if name.is_ident("float") => Ty::Float,
432            Meta::Path(ref name) if name.is_ident("double") => Ty::Double,
433            Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32,
434            Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64,
435            Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32,
436            Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64,
437            Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32,
438            Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64,
439            Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32,
440            Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64,
441            Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
442            Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
443            Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
444            Meta::Path(ref name) if name.is_ident("string") => Ty::String,
445            Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
446            Meta::NameValue(MetaNameValue {
447                ref path,
448                value:
449                    Expr::Lit(ExprLit {
450                        lit: Lit::Str(ref l),
451                        ..
452                    }),
453                ..
454            }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?),
455            Meta::NameValue(MetaNameValue {
456                ref path,
457                value:
458                    Expr::Lit(ExprLit {
459                        lit: Lit::Str(ref l),
460                        ..
461                    }),
462                ..
463            }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?),
464            Meta::List(ref meta_list) if meta_list.path.is_ident("enumeration") => {
465                Ty::Enumeration(meta_list.parse_args::<Path>()?)
466            }
467            _ => return Ok(None),
468        };
469        Ok(Some(ty))
470    }
471
472    pub fn from_str(s: &str) -> Result<Ty, Error> {
473        let enumeration_len = "enumeration".len();
474        let error = Err(anyhow!("invalid type: {}", s));
475        let ty = match s.trim() {
476            "float" => Ty::Float,
477            "double" => Ty::Double,
478            "int32" => Ty::Int32,
479            "int64" => Ty::Int64,
480            "uint32" => Ty::Uint32,
481            "uint64" => Ty::Uint64,
482            "sint32" => Ty::Sint32,
483            "sint64" => Ty::Sint64,
484            "fixed32" => Ty::Fixed32,
485            "fixed64" => Ty::Fixed64,
486            "sfixed32" => Ty::Sfixed32,
487            "sfixed64" => Ty::Sfixed64,
488            "bool" => Ty::Bool,
489            "string" => Ty::String,
490            "bytes" => Ty::Bytes(BytesTy::Vec),
491            s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
492                let s = &s[enumeration_len..].trim();
493                match s.chars().next() {
494                    Some('<') | Some('(') => (),
495                    _ => return error,
496                }
497                match s.chars().next_back() {
498                    Some('>') | Some(')') => (),
499                    _ => return error,
500                }
501
502                Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?)
503            }
504            _ => return error,
505        };
506        Ok(ty)
507    }
508
509    /// Returns the type as it appears in protobuf field declarations.
510    pub fn as_str(&self) -> &'static str {
511        match *self {
512            Ty::Double => "double",
513            Ty::Float => "float",
514            Ty::Int32 => "int32",
515            Ty::Int64 => "int64",
516            Ty::Uint32 => "uint32",
517            Ty::Uint64 => "uint64",
518            Ty::Sint32 => "sint32",
519            Ty::Sint64 => "sint64",
520            Ty::Fixed32 => "fixed32",
521            Ty::Fixed64 => "fixed64",
522            Ty::Sfixed32 => "sfixed32",
523            Ty::Sfixed64 => "sfixed64",
524            Ty::Bool => "bool",
525            Ty::String => "string",
526            Ty::Bytes(..) => "bytes",
527            Ty::Enumeration(..) => "enum",
528        }
529    }
530
531    // TODO: rename to 'owned_type'.
532    pub fn rust_type(&self) -> TokenStream {
533        match self {
534            Ty::String => quote!(::prost::alloc::string::String),
535            Ty::Bytes(ty) => ty.rust_type(),
536            _ => self.rust_ref_type(),
537        }
538    }
539
540    // TODO: rename to 'ref_type'
541    pub fn rust_ref_type(&self) -> TokenStream {
542        match *self {
543            Ty::Double => quote!(f64),
544            Ty::Float => quote!(f32),
545            Ty::Int32 => quote!(i32),
546            Ty::Int64 => quote!(i64),
547            Ty::Uint32 => quote!(u32),
548            Ty::Uint64 => quote!(u64),
549            Ty::Sint32 => quote!(i32),
550            Ty::Sint64 => quote!(i64),
551            Ty::Fixed32 => quote!(u32),
552            Ty::Fixed64 => quote!(u64),
553            Ty::Sfixed32 => quote!(i32),
554            Ty::Sfixed64 => quote!(i64),
555            Ty::Bool => quote!(bool),
556            Ty::String => quote!(&str),
557            Ty::Bytes(..) => quote!(&[u8]),
558            Ty::Enumeration(..) => quote!(i32),
559        }
560    }
561
562    pub fn module(&self) -> Ident {
563        match *self {
564            Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
565            _ => Ident::new(self.as_str(), Span::call_site()),
566        }
567    }
568
569    /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
570    pub fn is_numeric(&self) -> bool {
571        !matches!(self, Ty::String | Ty::Bytes(..))
572    }
573}
574
575impl fmt::Debug for Ty {
576    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
577        f.write_str(self.as_str())
578    }
579}
580
581impl fmt::Display for Ty {
582    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
583        f.write_str(self.as_str())
584    }
585}
586
587/// Scalar Protobuf field types.
588#[derive(Clone)]
589pub enum Kind {
590    /// A plain proto3 scalar field.
591    Plain(DefaultValue),
592    /// An optional scalar field.
593    Optional(DefaultValue),
594    /// A required proto2 scalar field.
595    Required(DefaultValue),
596    /// A repeated scalar field.
597    Repeated,
598    /// A packed repeated scalar field.
599    Packed,
600}
601
602/// Scalar Protobuf field default value.
603#[derive(Clone, Debug)]
604pub enum DefaultValue {
605    F64(f64),
606    F32(f32),
607    I32(i32),
608    I64(i64),
609    U32(u32),
610    U64(u64),
611    Bool(bool),
612    String(String),
613    Bytes(Vec<u8>),
614    Enumeration(TokenStream),
615    Path(Path),
616}
617
618impl DefaultValue {
619    pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> {
620        if !attr.path().is_ident("default") {
621            Ok(None)
622        } else if let Meta::NameValue(MetaNameValue {
623            value: Expr::Lit(ExprLit { ref lit, .. }),
624            ..
625        }) = *attr
626        {
627            Ok(Some(lit.clone()))
628        } else {
629            bail!("invalid default value attribute: {:?}", attr)
630        }
631    }
632
633    pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> {
634        let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32;
635        let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64;
636
637        let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32;
638        let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64;
639
640        let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty();
641
642        let default = match lit {
643            Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
644                DefaultValue::I32(lit.base10_parse()?)
645            }
646            Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
647                DefaultValue::I64(lit.base10_parse()?)
648            }
649            Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => {
650                DefaultValue::U32(lit.base10_parse()?)
651            }
652            Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => {
653                DefaultValue::U64(lit.base10_parse()?)
654            }
655
656            Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => {
657                DefaultValue::F32(lit.base10_parse()?)
658            }
659            Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?),
660
661            Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => {
662                DefaultValue::F64(lit.base10_parse()?)
663            }
664            Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),
665
666            Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
667            Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
668            Lit::ByteStr(ref lit)
669                if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
670            {
671                DefaultValue::Bytes(lit.value())
672            }
673
674            Lit::Str(ref lit) => {
675                let value = lit.value();
676                let value = value.trim();
677
678                if let Ty::Enumeration(ref path) = *ty {
679                    let variant = Ident::new(value, Span::call_site());
680                    return Ok(DefaultValue::Enumeration(quote!(#path::#variant)));
681                }
682
683                // Parse special floating point values.
684                if *ty == Ty::Float {
685                    match value {
686                        "inf" => {
687                            return Ok(DefaultValue::Path(parse_str::<Path>(
688                                "::core::f32::INFINITY",
689                            )?));
690                        }
691                        "-inf" => {
692                            return Ok(DefaultValue::Path(parse_str::<Path>(
693                                "::core::f32::NEG_INFINITY",
694                            )?));
695                        }
696                        "nan" => {
697                            return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?));
698                        }
699                        _ => (),
700                    }
701                }
702                if *ty == Ty::Double {
703                    match value {
704                        "inf" => {
705                            return Ok(DefaultValue::Path(parse_str::<Path>(
706                                "::core::f64::INFINITY",
707                            )?));
708                        }
709                        "-inf" => {
710                            return Ok(DefaultValue::Path(parse_str::<Path>(
711                                "::core::f64::NEG_INFINITY",
712                            )?));
713                        }
714                        "nan" => {
715                            return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?));
716                        }
717                        _ => (),
718                    }
719                }
720
721                // Rust doesn't have a negative literals, so they have to be parsed specially.
722                if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) {
723                    match lit {
724                        Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
725                            // Initially parse into an i64, so that i32::MIN does not overflow.
726                            let value: i64 = -lit.base10_parse()?;
727                            return Ok(i32::try_from(value).map(DefaultValue::I32)?);
728                        }
729                        Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
730                            // Initially parse into an i128, so that i64::MIN does not overflow.
731                            let value: i128 = -lit.base10_parse()?;
732                            return Ok(i64::try_from(value).map(DefaultValue::I64)?);
733                        }
734                        Lit::Float(ref lit)
735                            if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) =>
736                        {
737                            return Ok(DefaultValue::F32(-lit.base10_parse()?));
738                        }
739                        Lit::Float(ref lit)
740                            if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) =>
741                        {
742                            return Ok(DefaultValue::F64(-lit.base10_parse()?));
743                        }
744                        Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => {
745                            return Ok(DefaultValue::F32(-lit.base10_parse()?));
746                        }
747                        Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => {
748                            return Ok(DefaultValue::F64(-lit.base10_parse()?));
749                        }
750                        _ => (),
751                    }
752                }
753                match syn::parse_str::<Lit>(value) {
754                    Ok(Lit::Str(_)) => (),
755                    Ok(lit) => return DefaultValue::from_lit(ty, lit),
756                    _ => (),
757                }
758                bail!("invalid default value: {}", quote!(#value));
759            }
760            _ => bail!("invalid default value: {}", quote!(#lit)),
761        };
762
763        Ok(default)
764    }
765
766    pub fn new(ty: &Ty) -> DefaultValue {
767        match *ty {
768            Ty::Float => DefaultValue::F32(0.0),
769            Ty::Double => DefaultValue::F64(0.0),
770            Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0),
771            Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0),
772            Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0),
773            Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),
774
775            Ty::Bool => DefaultValue::Bool(false),
776            Ty::String => DefaultValue::String(String::new()),
777            Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
778            Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
779        }
780    }
781
782    pub fn owned(&self) -> TokenStream {
783        match *self {
784            DefaultValue::String(ref value) if value.is_empty() => {
785                quote!(::prost::alloc::string::String::new())
786            }
787            DefaultValue::String(ref value) => quote!(#value.into()),
788            DefaultValue::Bytes(ref value) if value.is_empty() => {
789                quote!(::core::default::Default::default())
790            }
791            DefaultValue::Bytes(ref value) => {
792                let lit = LitByteStr::new(value, Span::call_site());
793                quote!(#lit.as_ref().into())
794            }
795
796            ref other => other.typed(),
797        }
798    }
799
800    pub fn typed(&self) -> TokenStream {
801        if let DefaultValue::Enumeration(_) = *self {
802            quote!(#self as i32)
803        } else {
804            quote!(#self)
805        }
806    }
807}
808
809impl ToTokens for DefaultValue {
810    fn to_tokens(&self, tokens: &mut TokenStream) {
811        match *self {
812            DefaultValue::F64(value) => value.to_tokens(tokens),
813            DefaultValue::F32(value) => value.to_tokens(tokens),
814            DefaultValue::I32(value) => value.to_tokens(tokens),
815            DefaultValue::I64(value) => value.to_tokens(tokens),
816            DefaultValue::U32(value) => value.to_tokens(tokens),
817            DefaultValue::U64(value) => value.to_tokens(tokens),
818            DefaultValue::Bool(value) => value.to_tokens(tokens),
819            DefaultValue::String(ref value) => value.to_tokens(tokens),
820            DefaultValue::Bytes(ref value) => {
821                let byte_str = LitByteStr::new(value, Span::call_site());
822                tokens.append_all(quote!(#byte_str as &[u8]));
823            }
824            DefaultValue::Enumeration(ref value) => value.to_tokens(tokens),
825            DefaultValue::Path(ref value) => value.to_tokens(tokens),
826        }
827    }
828}