enumflags2_derive/
lib.rs

1#![recursion_limit = "2048"]
2extern crate proc_macro;
3#[macro_use]
4extern crate quote;
5
6use proc_macro2::{Span, TokenStream};
7use std::convert::TryFrom;
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    spanned::Spanned,
12    Expr, Ident, Item, ItemEnum, Token, Variant,
13};
14
/// One variant of the annotated enum, as seen by the `#[bitflags]` macro.
struct Flag<'a> {
    /// The variant's identifier.
    name: Ident,
    /// Span of the whole variant; errors about this flag are reported here.
    span: Span,
    /// How the variant's discriminant was written (or whether it is missing).
    value: FlagValue<'a>,
}
20
/// Classification of a variant's discriminant by what the macro can know about it.
enum FlagValue<'a> {
    /// An explicit discriminant that was folded to a constant at expansion time.
    Literal(u128),
    /// An explicit discriminant that could not be folded; it is validated by
    /// generated code at compile time instead.
    Deferred,
    /// No discriminant was written. The mutable borrow lets the macro write one
    /// into the variant later.
    Inferred(&'a mut Variant),
}
26
27impl FlagValue<'_> {
28    fn is_inferred(&self) -> bool {
29        matches!(self, FlagValue::Inferred(_))
30    }
31}
32
/// Arguments of the attribute: either empty, or `default = A | B | ...`.
struct Parameters {
    /// Variant names listed after `default =`; empty when no arguments were given.
    default: Vec<Ident>,
}
36
37impl Parse for Parameters {
38    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
39        if input.is_empty() {
40            return Ok(Parameters { default: vec![] });
41        }
42
43        input.parse::<Token![default]>()?;
44        input.parse::<Token![=]>()?;
45        let mut default = vec![input.parse()?];
46        while !input.is_empty() {
47            input.parse::<Token![|]>()?;
48            default.push(input.parse()?);
49        }
50
51        Ok(Parameters { default })
52    }
53}
54
55#[proc_macro_attribute]
56pub fn bitflags_internal(
57    attr: proc_macro::TokenStream,
58    input: proc_macro::TokenStream,
59) -> proc_macro::TokenStream {
60    let Parameters { default } = parse_macro_input!(attr as Parameters);
61    let mut ast = parse_macro_input!(input as Item);
62    let output = match ast {
63        Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default),
64        _ => Err(syn::Error::new_spanned(
65            &ast,
66            "#[bitflags] requires an enum",
67        )),
68    };
69
70    output
71        .unwrap_or_else(|err| {
72            let error = err.to_compile_error();
73            quote! {
74                #ast
75                #error
76            }
77        })
78        .into()
79}
80
81/// Try to evaluate the expression given.
82fn fold_expr(expr: &syn::Expr) -> Option<u128> {
83    match expr {
84        Expr::Lit(ref expr_lit) => match expr_lit.lit {
85            syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
86            _ => None,
87        },
88        Expr::Binary(ref expr_binary) => {
89            let l = fold_expr(&expr_binary.left)?;
90            let r = fold_expr(&expr_binary.right)?;
91            match &expr_binary.op {
92                syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r| l.checked_shl(r)),
93                _ => None,
94            }
95        }
96        Expr::Paren(syn::ExprParen { expr, .. }) | Expr::Group(syn::ExprGroup { expr, .. }) => {
97            fold_expr(expr)
98        }
99        _ => None,
100    }
101}
102
103fn collect_flags<'a>(
104    variants: impl Iterator<Item = &'a mut Variant>,
105) -> Result<Vec<Flag<'a>>, syn::Error> {
106    variants
107        .map(|variant| {
108            if !matches!(variant.fields, syn::Fields::Unit) {
109                return Err(syn::Error::new_spanned(
110                    &variant.fields,
111                    "Bitflag variants cannot contain additional data",
112                ));
113            }
114
115            let name = variant.ident.clone();
116            let span = variant.span();
117            let value = if let Some(ref expr) = variant.discriminant {
118                if let Some(n) = fold_expr(&expr.1) {
119                    FlagValue::Literal(n)
120                } else {
121                    FlagValue::Deferred
122                }
123            } else {
124                FlagValue::Inferred(variant)
125            };
126
127            Ok(Flag { name, span, value })
128        })
129        .collect()
130}
131
132fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
133    let tokens = if previous_variants.is_empty() {
134        quote!(1)
135    } else {
136        quote!(::enumflags2::_internal::next_bit(
137                #(#type_name::#previous_variants as u128)|*
138        ) as #repr)
139    };
140
141    syn::parse2(tokens).expect("couldn't parse inferred value")
142}
143
144fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
145    let mut previous_variants: Vec<Ident> = flags
146        .iter()
147        .filter(|flag| !flag.value.is_inferred())
148        .map(|flag| flag.name.clone())
149        .collect();
150
151    for flag in flags {
152        if let FlagValue::Inferred(ref mut variant) = flag.value {
153            variant.discriminant = Some((
154                <Token![=]>::default(),
155                inferred_value(type_name, &previous_variants, repr),
156            ));
157            previous_variants.push(flag.name.clone());
158        }
159    }
160}
161
162/// Given a list of attributes, find the `repr`, if any, and return the integer
163/// type specified.
164fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
165    let mut res = None;
166    for attr in attrs {
167        if attr.path().is_ident("repr") {
168            attr.parse_nested_meta(|meta| {
169                if let Some(ident) = meta.path.get_ident() {
170                    res = Some(ident.clone());
171                }
172                Ok(())
173            })?;
174        }
175    }
176    Ok(res)
177}
178
179/// Check the repr and return the number of bits available
180fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
181    // This would be so much easier if we could just match on an Ident...
182    if ty == "usize" {
183        Err(syn::Error::new_spanned(
184            ty,
185            "#[repr(usize)] is not supported. Use u32 or u64 instead.",
186        ))
187    } else if ty == "i8"
188        || ty == "i16"
189        || ty == "i32"
190        || ty == "i64"
191        || ty == "i128"
192        || ty == "isize"
193    {
194        Err(syn::Error::new_spanned(
195            ty,
196            "Signed types in a repr are not supported.",
197        ))
198    } else if ty == "u8" {
199        Ok(8)
200    } else if ty == "u16" {
201        Ok(16)
202    } else if ty == "u32" {
203        Ok(32)
204    } else if ty == "u64" {
205        Ok(64)
206    } else if ty == "u128" {
207        Ok(128)
208    } else {
209        Err(syn::Error::new_spanned(
210            ty,
211            "repr must be an integer type for #[bitflags].",
212        ))
213    }
214}
215
216/// Returns deferred checks
217fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
218    use FlagValue::*;
219    match flag.value {
220        Literal(n) => {
221            if !n.is_power_of_two() {
222                Err(syn::Error::new(
223                    flag.span,
224                    "Flags must have exactly one set bit",
225                ))
226            } else if bits < 128 && n >= 1 << bits {
227                Err(syn::Error::new(
228                    flag.span,
229                    format!("Flag value out of range for u{}", bits),
230                ))
231            } else {
232                Ok(None)
233            }
234        }
235        Inferred(_) => Ok(None),
236        Deferred => {
237            let variant_name = &flag.name;
238            Ok(Some(quote_spanned!(flag.span =>
239                const _:
240                    <<[(); (
241                        (#type_name::#variant_name as u128).is_power_of_two()
242                    ) as usize] as enumflags2::_internal::AssertionHelper>
243                        ::Status as enumflags2::_internal::ExactlyOneBitSet>::X
244                    = ();
245            )))
246        }
247    }
248}
249
/// Generate the `#[bitflags]` expansion for the enum `ast`.
///
/// Steps:
/// 1. Read and validate the integer `#[repr(...)]`.
/// 2. Classify each variant (`collect_flags`) and validate explicit
///    discriminants (`check_flag`).
/// 3. Assign discriminants to variants that lack one (`infer_values`, which
///    mutates `ast`).
/// 4. Check that the repr has at least as many bits as there are variants.
///
/// The output contains the (possibly modified) enum, any deferred compile-time
/// assertions, `Not`/`BitOr`/`BitAnd`/`BitXor` impls returning
/// `::enumflags2::BitFlags<Self>`, and impls of `RawBitFlags` and `BitFlag`.
///
/// `default` lists the variants OR-ed together into `RawBitFlags::DEFAULT`.
///
/// # Errors
/// Propagates errors from the helpers above: missing or unsupported repr,
/// variants with fields, and invalid literal discriminants. Also reports
/// "Not enough bits" when the enum has more variants than the repr has bits.
fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
    let ident = &ast.ident;

    // Every token emitted below is given the call-site span.
    let span = Span::call_site();

    let repr = extract_repr(&ast.attrs)?
        .ok_or_else(|| syn::Error::new_spanned(ident,
                        "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
    let bits = type_bits(&repr)?;

    let mut variants = collect_flags(ast.variants.iter_mut())?;
    // `transpose` turns each `Result<Option<_>>` into `Option<Result<_>>`, so
    // `flat_map` drops flags that need no deferred check while an error still
    // short-circuits the `collect`.
    let deferred = variants
        .iter()
        .flat_map(|variant| check_flag(ident, variant, bits).transpose())
        .collect::<Result<Vec<_>, _>>()?;

    // Writes discriminants into `ast.variants` through the `&mut Variant`
    // borrows held in `variants`, so `#ast` below is emitted with them filled in.
    infer_values(&mut variants, ident, &repr);

    // Every flag needs a distinct bit, so the repr must be at least this wide.
    if (bits as usize) < variants.len() {
        return Err(syn::Error::new_spanned(
            &repr,
            format!("Not enough bits for {} flags", variants.len()),
        ));
    }

    // The generated impls name `core` items through this path inside enumflags2.
    let std = quote_spanned!(span => ::enumflags2::_internal::core);
    let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();

    Ok(quote_spanned! {
        span =>
            #ast
            #(#deferred)*
            impl #std::ops::Not for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn not(self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self).not()
                }
            }

            impl #std::ops::BitOr for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitor(self, other: Self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self) | other
                }
            }

            impl #std::ops::BitAnd for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitand(self, other: Self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self) & other
                }
            }

            impl #std::ops::BitXor for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitxor(self, other: Self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self) ^ other
                }
            }

            unsafe impl ::enumflags2::_internal::RawBitFlags for #ident {
                type Numeric = #repr;

                const EMPTY: Self::Numeric = 0;

                // OR of the variants given as `default = ...` (0 when none).
                const DEFAULT: Self::Numeric =
                    0 #(| (Self::#default as #repr))*;

                // OR of every variant's value.
                const ALL_BITS: Self::Numeric =
                    0 #(| (Self::#variant_names as #repr))*;

                const BITFLAGS_TYPE_NAME : &'static str =
                    concat!("BitFlags<", stringify!(#ident), ">");

                fn bits(self) -> Self::Numeric {
                    self as #repr
                }
            }

            impl ::enumflags2::BitFlag for #ident {}
    })
}