1#![doc(html_root_url = "https://docs.rs/prost-derive/0.13.5")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
14 FieldsUnnamed, Ident, Index, Variant,
15};
16
17mod field;
18use crate::field::Field;
19
20fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
21 let input: DeriveInput = syn::parse2(input)?;
22
23 let ident = input.ident;
24
25 syn::custom_keyword!(skip_debug);
26 let skip_debug = input
27 .attrs
28 .into_iter()
29 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
30
31 let variant_data = match input.data {
32 Data::Struct(variant_data) => variant_data,
33 Data::Enum(..) => bail!("Message can not be derived for an enum"),
34 Data::Union(..) => bail!("Message can not be derived for a union"),
35 };
36
37 let generics = &input.generics;
38 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40 let (is_struct, fields) = match variant_data {
41 DataStruct {
42 fields: Fields::Named(FieldsNamed { named: fields, .. }),
43 ..
44 } => (true, fields.into_iter().collect()),
45 DataStruct {
46 fields:
47 Fields::Unnamed(FieldsUnnamed {
48 unnamed: fields, ..
49 }),
50 ..
51 } => (false, fields.into_iter().collect()),
52 DataStruct {
53 fields: Fields::Unit,
54 ..
55 } => (false, Vec::new()),
56 };
57
58 let mut next_tag: u32 = 1;
59 let mut fields = fields
60 .into_iter()
61 .enumerate()
62 .flat_map(|(i, field)| {
63 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
64 let index = Index {
65 index: i as u32,
66 span: Span::call_site(),
67 };
68 quote!(#index)
69 });
70 match Field::new(field.attrs, Some(next_tag)) {
71 Ok(Some(field)) => {
72 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
73 Some(Ok((field_ident, field)))
74 }
75 Ok(None) => None,
76 Err(err) => Some(Err(
77 err.context(format!("invalid message field {}.{}", ident, field_ident))
78 )),
79 }
80 })
81 .collect::<Result<Vec<_>, _>>()?;
82
83 let unsorted_fields = fields.clone();
85
86 fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
91 let fields = fields;
92
93 if let Some(duplicate_tag) = fields
94 .iter()
95 .flat_map(|(_, field)| field.tags())
96 .duplicates()
97 .next()
98 {
99 bail!(
100 "message {} has multiple fields with tag {}",
101 ident,
102 duplicate_tag
103 )
104 };
105
106 let encoded_len = fields
107 .iter()
108 .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident)));
109
110 let encode = fields
111 .iter()
112 .map(|(field_ident, field)| field.encode(quote!(self.#field_ident)));
113
114 let merge = fields.iter().map(|(field_ident, field)| {
115 let merge = field.merge(quote!(value));
116 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
117 let tags = Itertools::intersperse(tags, quote!(|));
118
119 quote! {
120 #(#tags)* => {
121 let mut value = &mut self.#field_ident;
122 #merge.map_err(|mut error| {
123 error.push(STRUCT_NAME, stringify!(#field_ident));
124 error
125 })
126 },
127 }
128 });
129
130 let struct_name = if fields.is_empty() {
131 quote!()
132 } else {
133 quote!(
134 const STRUCT_NAME: &'static str = stringify!(#ident);
135 )
136 };
137
138 let clear = fields
139 .iter()
140 .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
141
142 let default = if is_struct {
143 let default = fields.iter().map(|(field_ident, field)| {
144 let value = field.default();
145 quote!(#field_ident: #value,)
146 });
147 quote! {#ident {
148 #(#default)*
149 }}
150 } else {
151 let default = fields.iter().map(|(_, field)| {
152 let value = field.default();
153 quote!(#value,)
154 });
155 quote! {#ident (
156 #(#default)*
157 )}
158 };
159
160 let methods = fields
161 .iter()
162 .flat_map(|(field_ident, field)| field.methods(field_ident))
163 .collect::<Vec<_>>();
164 let methods = if methods.is_empty() {
165 quote!()
166 } else {
167 quote! {
168 #[allow(dead_code)]
169 impl #impl_generics #ident #ty_generics #where_clause {
170 #(#methods)*
171 }
172 }
173 };
174
175 let expanded = quote! {
176 impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
177 #[allow(unused_variables)]
178 fn encode_raw(&self, buf: &mut impl ::prost::bytes::BufMut) {
179 #(#encode)*
180 }
181
182 #[allow(unused_variables)]
183 fn merge_field(
184 &mut self,
185 tag: u32,
186 wire_type: ::prost::encoding::wire_type::WireType,
187 buf: &mut impl ::prost::bytes::Buf,
188 ctx: ::prost::encoding::DecodeContext,
189 ) -> ::core::result::Result<(), ::prost::DecodeError>
190 {
191 #struct_name
192 match tag {
193 #(#merge)*
194 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
195 }
196 }
197
198 #[inline]
199 fn encoded_len(&self) -> usize {
200 0 #(+ #encoded_len)*
201 }
202
203 fn clear(&mut self) {
204 #(#clear;)*
205 }
206 }
207
208 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
209 fn default() -> Self {
210 #default
211 }
212 }
213 };
214 let expanded = if skip_debug {
215 expanded
216 } else {
217 let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
218 let wrapper = field.debug(quote!(self.#field_ident));
219 let call = if is_struct {
220 quote!(builder.field(stringify!(#field_ident), &wrapper))
221 } else {
222 quote!(builder.field(&wrapper))
223 };
224 quote! {
225 let builder = {
226 let wrapper = #wrapper;
227 #call
228 };
229 }
230 });
231 let debug_builder = if is_struct {
232 quote!(f.debug_struct(stringify!(#ident)))
233 } else {
234 quote!(f.debug_tuple(stringify!(#ident)))
235 };
236 quote! {
237 #expanded
238
239 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
240 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
241 let mut builder = #debug_builder;
242 #(#debugs;)*
243 builder.finish()
244 }
245 }
246 }
247 };
248
249 let expanded = quote! {
250 #expanded
251
252 #methods
253 };
254
255 Ok(expanded)
256}
257
258#[proc_macro_derive(Message, attributes(prost))]
259pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
260 try_message(input.into()).unwrap().into()
261}
262
263fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
264 let input: DeriveInput = syn::parse2(input)?;
265 let ident = input.ident;
266
267 let generics = &input.generics;
268 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
269
270 let punctuated_variants = match input.data {
271 Data::Enum(DataEnum { variants, .. }) => variants,
272 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
273 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
274 };
275
276 let mut variants: Vec<(Ident, Expr)> = Vec::new();
278 for Variant {
279 ident,
280 fields,
281 discriminant,
282 ..
283 } in punctuated_variants
284 {
285 match fields {
286 Fields::Unit => (),
287 Fields::Named(_) | Fields::Unnamed(_) => {
288 bail!("Enumeration variants may not have fields")
289 }
290 }
291
292 match discriminant {
293 Some((_, expr)) => variants.push((ident, expr)),
294 None => bail!("Enumeration variants must have a discriminant"),
295 }
296 }
297
298 if variants.is_empty() {
299 panic!("Enumeration must have at least one variant");
300 }
301
302 let default = variants[0].0.clone();
303
304 let is_valid = variants.iter().map(|(_, value)| quote!(#value => true));
305 let from = variants
306 .iter()
307 .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)));
308
309 let try_from = variants
310 .iter()
311 .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)));
312
313 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
314 let from_i32_doc = format!(
315 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
316 ident
317 );
318
319 let expanded = quote! {
320 impl #impl_generics #ident #ty_generics #where_clause {
321 #[doc=#is_valid_doc]
322 pub fn is_valid(value: i32) -> bool {
323 match value {
324 #(#is_valid,)*
325 _ => false,
326 }
327 }
328
329 #[deprecated = "Use the TryFrom<i32> implementation instead"]
330 #[doc=#from_i32_doc]
331 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
332 match value {
333 #(#from,)*
334 _ => ::core::option::Option::None,
335 }
336 }
337 }
338
339 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
340 fn default() -> #ident {
341 #ident::#default
342 }
343 }
344
345 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
346 fn from(value: #ident) -> i32 {
347 value as i32
348 }
349 }
350
351 impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
352 type Error = ::prost::UnknownEnumValue;
353
354 fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::UnknownEnumValue> {
355 match value {
356 #(#try_from,)*
357 _ => ::core::result::Result::Err(::prost::UnknownEnumValue(value)),
358 }
359 }
360 }
361 };
362
363 Ok(expanded)
364}
365
366#[proc_macro_derive(Enumeration, attributes(prost))]
367pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
368 try_enumeration(input.into()).unwrap().into()
369}
370
371fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
372 let input: DeriveInput = syn::parse2(input)?;
373
374 let ident = input.ident;
375
376 syn::custom_keyword!(skip_debug);
377 let skip_debug = input
378 .attrs
379 .into_iter()
380 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
381
382 let variants = match input.data {
383 Data::Enum(DataEnum { variants, .. }) => variants,
384 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
385 Data::Union(..) => bail!("Oneof can not be derived for a union"),
386 };
387
388 let generics = &input.generics;
389 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
390
391 let mut fields: Vec<(Ident, Field)> = Vec::new();
393 for Variant {
394 attrs,
395 ident: variant_ident,
396 fields: variant_fields,
397 ..
398 } in variants
399 {
400 let variant_fields = match variant_fields {
401 Fields::Unit => Punctuated::new(),
402 Fields::Named(FieldsNamed { named: fields, .. })
403 | Fields::Unnamed(FieldsUnnamed {
404 unnamed: fields, ..
405 }) => fields,
406 };
407 if variant_fields.len() != 1 {
408 bail!("Oneof enum variants must have a single field");
409 }
410 match Field::new_oneof(attrs)? {
411 Some(field) => fields.push((variant_ident, field)),
412 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
413 }
414 }
415
416 assert!(fields.iter().all(|(_, field)| field.tags().len() == 1));
419
420 if let Some(duplicate_tag) = fields
421 .iter()
422 .flat_map(|(_, field)| field.tags())
423 .duplicates()
424 .next()
425 {
426 bail!(
427 "invalid oneof {}: multiple variants have tag {}",
428 ident,
429 duplicate_tag
430 );
431 }
432
433 let encode = fields.iter().map(|(variant_ident, field)| {
434 let encode = field.encode(quote!(*value));
435 quote!(#ident::#variant_ident(ref value) => { #encode })
436 });
437
438 let merge = fields.iter().map(|(variant_ident, field)| {
439 let tag = field.tags()[0];
440 let merge = field.merge(quote!(value));
441 quote! {
442 #tag => {
443 match field {
444 ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
445 #merge
446 },
447 _ => {
448 let mut owned_value = ::core::default::Default::default();
449 let value = &mut owned_value;
450 #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
451 },
452 }
453 }
454 }
455 });
456
457 let encoded_len = fields.iter().map(|(variant_ident, field)| {
458 let encoded_len = field.encoded_len(quote!(*value));
459 quote!(#ident::#variant_ident(ref value) => #encoded_len)
460 });
461
462 let expanded = quote! {
463 impl #impl_generics #ident #ty_generics #where_clause {
464 pub fn encode(&self, buf: &mut impl ::prost::bytes::BufMut) {
466 match *self {
467 #(#encode,)*
468 }
469 }
470
471 pub fn merge(
473 field: &mut ::core::option::Option<#ident #ty_generics>,
474 tag: u32,
475 wire_type: ::prost::encoding::wire_type::WireType,
476 buf: &mut impl ::prost::bytes::Buf,
477 ctx: ::prost::encoding::DecodeContext,
478 ) -> ::core::result::Result<(), ::prost::DecodeError>
479 {
480 match tag {
481 #(#merge,)*
482 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
483 }
484 }
485
486 #[inline]
488 pub fn encoded_len(&self) -> usize {
489 match *self {
490 #(#encoded_len,)*
491 }
492 }
493 }
494
495 };
496 let expanded = if skip_debug {
497 expanded
498 } else {
499 let debug = fields.iter().map(|(variant_ident, field)| {
500 let wrapper = field.debug(quote!(*value));
501 quote!(#ident::#variant_ident(ref value) => {
502 let wrapper = #wrapper;
503 f.debug_tuple(stringify!(#variant_ident))
504 .field(&wrapper)
505 .finish()
506 })
507 });
508 quote! {
509 #expanded
510
511 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
512 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
513 match *self {
514 #(#debug,)*
515 }
516 }
517 }
518 }
519 };
520
521 Ok(expanded)
522}
523
524#[proc_macro_derive(Oneof, attributes(prost))]
525pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
526 try_oneof(input.into()).unwrap().into()
527}
528
#[cfg(test)]
mod test {
    use crate::{try_message, try_oneof};
    use quote::quote;

    /// Returns the `Display` text of a failed expansion, panicking with
    /// `reason` if the expansion unexpectedly succeeded.
    fn rejection<T: std::fmt::Debug, E: std::fmt::Display>(
        output: Result<T, E>,
        reason: &str,
    ) -> String {
        output.expect_err(reason).to_string()
    }

    #[test]
    fn test_rejects_colliding_message_fields() {
        let output = try_message(quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        ));
        let message = rejection(output, "did not reject colliding message fields");
        assert_eq!(message, "message Invalid has multiple fields with tag 1");
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        let output = try_oneof(quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        ));
        let message = rejection(output, "did not reject colliding oneof variants");
        assert_eq!(message, "invalid oneof Invalid: multiple variants have tag 1");
    }

    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        let reason = "did not reject multiple tags on oneof variant";

        // Two `tag` keys inside one attribute.
        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "1", tag = "2")]
                A(bool),
            }
        ));
        assert_eq!(
            rejection(output, reason),
            "duplicate tag attributes: 1 and 2"
        );

        // `tag` keys split across two attributes.
        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "3")]
                #[prost(tag = "4")]
                A(bool),
            }
        ));
        assert!(output.is_err());
        assert_eq!(
            rejection(output, reason),
            "duplicate tag attributes: 3 and 4"
        );

        // `tags` (plural) is not accepted on a oneof variant.
        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tags = "5,6")]
                A(bool),
            }
        ));
        assert!(output.is_err());
        assert_eq!(
            rejection(output, reason),
            "unknown attribute(s): #[prost(tags = \"5,6\")]"
        );
    }
}