prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/prost-derive/0.13.5")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
14    FieldsUnnamed, Ident, Index, Variant,
15};
16
17mod field;
18use crate::field::Field;
19
20fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
21    let input: DeriveInput = syn::parse2(input)?;
22
23    let ident = input.ident;
24
25    syn::custom_keyword!(skip_debug);
26    let skip_debug = input
27        .attrs
28        .into_iter()
29        .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
30
31    let variant_data = match input.data {
32        Data::Struct(variant_data) => variant_data,
33        Data::Enum(..) => bail!("Message can not be derived for an enum"),
34        Data::Union(..) => bail!("Message can not be derived for a union"),
35    };
36
37    let generics = &input.generics;
38    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40    let (is_struct, fields) = match variant_data {
41        DataStruct {
42            fields: Fields::Named(FieldsNamed { named: fields, .. }),
43            ..
44        } => (true, fields.into_iter().collect()),
45        DataStruct {
46            fields:
47                Fields::Unnamed(FieldsUnnamed {
48                    unnamed: fields, ..
49                }),
50            ..
51        } => (false, fields.into_iter().collect()),
52        DataStruct {
53            fields: Fields::Unit,
54            ..
55        } => (false, Vec::new()),
56    };
57
58    let mut next_tag: u32 = 1;
59    let mut fields = fields
60        .into_iter()
61        .enumerate()
62        .flat_map(|(i, field)| {
63            let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
64                let index = Index {
65                    index: i as u32,
66                    span: Span::call_site(),
67                };
68                quote!(#index)
69            });
70            match Field::new(field.attrs, Some(next_tag)) {
71                Ok(Some(field)) => {
72                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
73                    Some(Ok((field_ident, field)))
74                }
75                Ok(None) => None,
76                Err(err) => Some(Err(
77                    err.context(format!("invalid message field {}.{}", ident, field_ident))
78                )),
79            }
80        })
81        .collect::<Result<Vec<_>, _>>()?;
82
83    // We want Debug to be in declaration order
84    let unsorted_fields = fields.clone();
85
86    // Sort the fields by tag number so that fields will be encoded in tag order.
87    // TODO: This encodes oneof fields in the position of their lowest tag,
88    // regardless of the currently occupied variant, is that consequential?
89    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
90    fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
91    let fields = fields;
92
93    if let Some(duplicate_tag) = fields
94        .iter()
95        .flat_map(|(_, field)| field.tags())
96        .duplicates()
97        .next()
98    {
99        bail!(
100            "message {} has multiple fields with tag {}",
101            ident,
102            duplicate_tag
103        )
104    };
105
106    let encoded_len = fields
107        .iter()
108        .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident)));
109
110    let encode = fields
111        .iter()
112        .map(|(field_ident, field)| field.encode(quote!(self.#field_ident)));
113
114    let merge = fields.iter().map(|(field_ident, field)| {
115        let merge = field.merge(quote!(value));
116        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
117        let tags = Itertools::intersperse(tags, quote!(|));
118
119        quote! {
120            #(#tags)* => {
121                let mut value = &mut self.#field_ident;
122                #merge.map_err(|mut error| {
123                    error.push(STRUCT_NAME, stringify!(#field_ident));
124                    error
125                })
126            },
127        }
128    });
129
130    let struct_name = if fields.is_empty() {
131        quote!()
132    } else {
133        quote!(
134            const STRUCT_NAME: &'static str = stringify!(#ident);
135        )
136    };
137
138    let clear = fields
139        .iter()
140        .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
141
142    let default = if is_struct {
143        let default = fields.iter().map(|(field_ident, field)| {
144            let value = field.default();
145            quote!(#field_ident: #value,)
146        });
147        quote! {#ident {
148            #(#default)*
149        }}
150    } else {
151        let default = fields.iter().map(|(_, field)| {
152            let value = field.default();
153            quote!(#value,)
154        });
155        quote! {#ident (
156            #(#default)*
157        )}
158    };
159
160    let methods = fields
161        .iter()
162        .flat_map(|(field_ident, field)| field.methods(field_ident))
163        .collect::<Vec<_>>();
164    let methods = if methods.is_empty() {
165        quote!()
166    } else {
167        quote! {
168            #[allow(dead_code)]
169            impl #impl_generics #ident #ty_generics #where_clause {
170                #(#methods)*
171            }
172        }
173    };
174
175    let expanded = quote! {
176        impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
177            #[allow(unused_variables)]
178            fn encode_raw(&self, buf: &mut impl ::prost::bytes::BufMut) {
179                #(#encode)*
180            }
181
182            #[allow(unused_variables)]
183            fn merge_field(
184                &mut self,
185                tag: u32,
186                wire_type: ::prost::encoding::wire_type::WireType,
187                buf: &mut impl ::prost::bytes::Buf,
188                ctx: ::prost::encoding::DecodeContext,
189            ) -> ::core::result::Result<(), ::prost::DecodeError>
190            {
191                #struct_name
192                match tag {
193                    #(#merge)*
194                    _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
195                }
196            }
197
198            #[inline]
199            fn encoded_len(&self) -> usize {
200                0 #(+ #encoded_len)*
201            }
202
203            fn clear(&mut self) {
204                #(#clear;)*
205            }
206        }
207
208        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
209            fn default() -> Self {
210                #default
211            }
212        }
213    };
214    let expanded = if skip_debug {
215        expanded
216    } else {
217        let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
218            let wrapper = field.debug(quote!(self.#field_ident));
219            let call = if is_struct {
220                quote!(builder.field(stringify!(#field_ident), &wrapper))
221            } else {
222                quote!(builder.field(&wrapper))
223            };
224            quote! {
225                 let builder = {
226                     let wrapper = #wrapper;
227                     #call
228                 };
229            }
230        });
231        let debug_builder = if is_struct {
232            quote!(f.debug_struct(stringify!(#ident)))
233        } else {
234            quote!(f.debug_tuple(stringify!(#ident)))
235        };
236        quote! {
237            #expanded
238
239            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
240                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
241                    let mut builder = #debug_builder;
242                    #(#debugs;)*
243                    builder.finish()
244                }
245            }
246        }
247    };
248
249    let expanded = quote! {
250        #expanded
251
252        #methods
253    };
254
255    Ok(expanded)
256}
257
258#[proc_macro_derive(Message, attributes(prost))]
259pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
260    try_message(input.into()).unwrap().into()
261}
262
263fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
264    let input: DeriveInput = syn::parse2(input)?;
265    let ident = input.ident;
266
267    let generics = &input.generics;
268    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
269
270    let punctuated_variants = match input.data {
271        Data::Enum(DataEnum { variants, .. }) => variants,
272        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
273        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
274    };
275
276    // Map the variants into 'fields'.
277    let mut variants: Vec<(Ident, Expr)> = Vec::new();
278    for Variant {
279        ident,
280        fields,
281        discriminant,
282        ..
283    } in punctuated_variants
284    {
285        match fields {
286            Fields::Unit => (),
287            Fields::Named(_) | Fields::Unnamed(_) => {
288                bail!("Enumeration variants may not have fields")
289            }
290        }
291
292        match discriminant {
293            Some((_, expr)) => variants.push((ident, expr)),
294            None => bail!("Enumeration variants must have a discriminant"),
295        }
296    }
297
298    if variants.is_empty() {
299        panic!("Enumeration must have at least one variant");
300    }
301
302    let default = variants[0].0.clone();
303
304    let is_valid = variants.iter().map(|(_, value)| quote!(#value => true));
305    let from = variants
306        .iter()
307        .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)));
308
309    let try_from = variants
310        .iter()
311        .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)));
312
313    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
314    let from_i32_doc = format!(
315        "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
316        ident
317    );
318
319    let expanded = quote! {
320        impl #impl_generics #ident #ty_generics #where_clause {
321            #[doc=#is_valid_doc]
322            pub fn is_valid(value: i32) -> bool {
323                match value {
324                    #(#is_valid,)*
325                    _ => false,
326                }
327            }
328
329            #[deprecated = "Use the TryFrom<i32> implementation instead"]
330            #[doc=#from_i32_doc]
331            pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
332                match value {
333                    #(#from,)*
334                    _ => ::core::option::Option::None,
335                }
336            }
337        }
338
339        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
340            fn default() -> #ident {
341                #ident::#default
342            }
343        }
344
345        impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
346            fn from(value: #ident) -> i32 {
347                value as i32
348            }
349        }
350
351        impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
352            type Error = ::prost::UnknownEnumValue;
353
354            fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::UnknownEnumValue> {
355                match value {
356                    #(#try_from,)*
357                    _ => ::core::result::Result::Err(::prost::UnknownEnumValue(value)),
358                }
359            }
360        }
361    };
362
363    Ok(expanded)
364}
365
366#[proc_macro_derive(Enumeration, attributes(prost))]
367pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
368    try_enumeration(input.into()).unwrap().into()
369}
370
371fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
372    let input: DeriveInput = syn::parse2(input)?;
373
374    let ident = input.ident;
375
376    syn::custom_keyword!(skip_debug);
377    let skip_debug = input
378        .attrs
379        .into_iter()
380        .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
381
382    let variants = match input.data {
383        Data::Enum(DataEnum { variants, .. }) => variants,
384        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
385        Data::Union(..) => bail!("Oneof can not be derived for a union"),
386    };
387
388    let generics = &input.generics;
389    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
390
391    // Map the variants into 'fields'.
392    let mut fields: Vec<(Ident, Field)> = Vec::new();
393    for Variant {
394        attrs,
395        ident: variant_ident,
396        fields: variant_fields,
397        ..
398    } in variants
399    {
400        let variant_fields = match variant_fields {
401            Fields::Unit => Punctuated::new(),
402            Fields::Named(FieldsNamed { named: fields, .. })
403            | Fields::Unnamed(FieldsUnnamed {
404                unnamed: fields, ..
405            }) => fields,
406        };
407        if variant_fields.len() != 1 {
408            bail!("Oneof enum variants must have a single field");
409        }
410        match Field::new_oneof(attrs)? {
411            Some(field) => fields.push((variant_ident, field)),
412            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
413        }
414    }
415
416    // Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple
417    // tags.
418    assert!(fields.iter().all(|(_, field)| field.tags().len() == 1));
419
420    if let Some(duplicate_tag) = fields
421        .iter()
422        .flat_map(|(_, field)| field.tags())
423        .duplicates()
424        .next()
425    {
426        bail!(
427            "invalid oneof {}: multiple variants have tag {}",
428            ident,
429            duplicate_tag
430        );
431    }
432
433    let encode = fields.iter().map(|(variant_ident, field)| {
434        let encode = field.encode(quote!(*value));
435        quote!(#ident::#variant_ident(ref value) => { #encode })
436    });
437
438    let merge = fields.iter().map(|(variant_ident, field)| {
439        let tag = field.tags()[0];
440        let merge = field.merge(quote!(value));
441        quote! {
442            #tag => {
443                match field {
444                    ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
445                        #merge
446                    },
447                    _ => {
448                        let mut owned_value = ::core::default::Default::default();
449                        let value = &mut owned_value;
450                        #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
451                    },
452                }
453            }
454        }
455    });
456
457    let encoded_len = fields.iter().map(|(variant_ident, field)| {
458        let encoded_len = field.encoded_len(quote!(*value));
459        quote!(#ident::#variant_ident(ref value) => #encoded_len)
460    });
461
462    let expanded = quote! {
463        impl #impl_generics #ident #ty_generics #where_clause {
464            /// Encodes the message to a buffer.
465            pub fn encode(&self, buf: &mut impl ::prost::bytes::BufMut) {
466                match *self {
467                    #(#encode,)*
468                }
469            }
470
471            /// Decodes an instance of the message from a buffer, and merges it into self.
472            pub fn merge(
473                field: &mut ::core::option::Option<#ident #ty_generics>,
474                tag: u32,
475                wire_type: ::prost::encoding::wire_type::WireType,
476                buf: &mut impl ::prost::bytes::Buf,
477                ctx: ::prost::encoding::DecodeContext,
478            ) -> ::core::result::Result<(), ::prost::DecodeError>
479            {
480                match tag {
481                    #(#merge,)*
482                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
483                }
484            }
485
486            /// Returns the encoded length of the message without a length delimiter.
487            #[inline]
488            pub fn encoded_len(&self) -> usize {
489                match *self {
490                    #(#encoded_len,)*
491                }
492            }
493        }
494
495    };
496    let expanded = if skip_debug {
497        expanded
498    } else {
499        let debug = fields.iter().map(|(variant_ident, field)| {
500            let wrapper = field.debug(quote!(*value));
501            quote!(#ident::#variant_ident(ref value) => {
502                let wrapper = #wrapper;
503                f.debug_tuple(stringify!(#variant_ident))
504                    .field(&wrapper)
505                    .finish()
506            })
507        });
508        quote! {
509            #expanded
510
511            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
512                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
513                    match *self {
514                        #(#debug,)*
515                    }
516                }
517            }
518        }
519    };
520
521    Ok(expanded)
522}
523
524#[proc_macro_derive(Oneof, attributes(prost))]
525pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
526    try_oneof(input.into()).unwrap().into()
527}
528
#[cfg(test)]
mod test {
    use crate::{try_message, try_oneof};
    use anyhow::Error;
    use proc_macro2::TokenStream;
    use quote::quote;

    /// Requires that a derive expansion failed and returns its rendered error message.
    ///
    /// `why` is the panic message used if the expansion unexpectedly succeeded.
    fn rejection(output: Result<TokenStream, Error>, why: &str) -> String {
        output.expect_err(why).to_string()
    }

    #[test]
    fn test_rejects_colliding_message_fields() {
        let input = quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        );
        let message = rejection(
            try_message(input),
            "did not reject colliding message fields",
        );
        assert_eq!(message, "message Invalid has multiple fields with tag 1");
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        let input = quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        );
        let message = rejection(
            try_oneof(input),
            "did not reject colliding oneof variants",
        );
        assert_eq!(message, "invalid oneof Invalid: multiple variants have tag 1");
    }

    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        const WHY: &str = "did not reject multiple tags on oneof variant";

        // Two `tag` keys inside a single attribute.
        let same_attribute = quote!(
            enum What {
                #[prost(bool, tag = "1", tag = "2")]
                A(bool),
            }
        );
        assert_eq!(
            rejection(try_oneof(same_attribute), WHY),
            "duplicate tag attributes: 1 and 2"
        );

        // The tags are spread across two separate `#[prost]` attributes.
        let split_attributes = quote!(
            enum What {
                #[prost(bool, tag = "3")]
                #[prost(tag = "4")]
                A(bool),
            }
        );
        assert_eq!(
            rejection(try_oneof(split_attributes), WHY),
            "duplicate tag attributes: 3 and 4"
        );

        // The plural `tags` key is not accepted on oneof variants.
        let plural_tags = quote!(
            enum What {
                #[prost(bool, tags = "5,6")]
                A(bool),
            }
        );
        assert_eq!(
            rejection(try_oneof(plural_tags), WHY),
            "unknown attribute(s): #[prost(tags = \"5,6\")]"
        );
    }
}