enum_kinds/
lib.rs

1#![doc = include_str!("../README.md")]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5extern crate quote;
6#[macro_use]
7extern crate syn;
8
9use proc_macro2::TokenStream;
10use quote::quote;
11use std::collections::HashSet;
12use syn::punctuated::Punctuated;
13use syn::{
14    Data, DataEnum, DeriveInput, Fields, GenericParam, Lifetime, LifetimeDef, Meta, MetaList,
15    MetaNameValue, NestedMeta, Path,
16};
17
18#[proc_macro_derive(EnumKind, attributes(enum_kind))]
19pub fn enum_kind(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
20    let ast = syn::parse(input).expect("#[derive(EnumKind)] failed to parse input");
21    let (name, traits) = get_enum_specification(&ast);
22    let enum_ = create_kind_enum(&ast, &name, traits);
23    let impl_ = create_impl(&ast, &name);
24    let code = quote! {
25        #enum_
26        #impl_
27    };
28    proc_macro::TokenStream::from(code)
29}
30
31fn find_attribute(
32    definition: &DeriveInput,
33    name: &str,
34) -> Option<Punctuated<NestedMeta, syn::token::Comma>> {
35    for attr in definition.attrs.iter() {
36        match attr.parse_meta() {
37            Ok(Meta::List(MetaList {
38                ref path,
39                ref nested,
40                ..
41            })) if path.is_ident(name) => return Some(nested.clone()),
42            _ => continue,
43        }
44    }
45    None
46}
47
48fn get_enum_specification(definition: &DeriveInput) -> (Path, Vec<NestedMeta>) {
49    let params = find_attribute(definition, "enum_kind")
50        .expect("#[derive(EnumKind)] requires an associated enum_kind attribute to be specified");
51    let mut iter = params.iter();
52    if let Some(&NestedMeta::Meta(Meta::Path(ref path))) = iter.next() {
53        return (path.to_owned(), iter.cloned().collect());
54    } else {
55        panic!("#[enum_kind(NAME)] attribute requires NAME to be specified");
56    }
57}
58
59fn has_docs(traits: &[NestedMeta]) -> bool {
60    traits.iter().any(|attr| {
61        if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, .. })) = attr {
62            path.is_ident("doc")
63        } else {
64            false
65        }
66    })
67}
68
69fn create_kind_enum(
70    definition: &DeriveInput,
71    kind_ident: &Path,
72    traits: Vec<NestedMeta>,
73) -> TokenStream {
74    let variant_idents = match &definition.data {
75        &Data::Enum(DataEnum { ref variants, .. }) => variants.iter().map(|ref v| v.ident.clone()),
76        _ => {
77            panic!("#[derive(EnumKind)] is only allowed for enums");
78        }
79    };
80    let visibility = &definition.vis;
81    let docs_attr = if !has_docs(traits.as_ref()) {
82        quote! {#[allow(missing_docs)]}
83    } else {
84        quote! {}
85    };
86    let code = quote! {
87        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
88        #[allow(dead_code)]
89        #docs_attr
90        #( #[#traits] )*
91        #visibility enum #kind_ident {
92            #(#variant_idents),*
93        }
94    };
95    TokenStream::from(code)
96}
97
98fn is_uninhabited_enum(definition: &DeriveInput) -> bool {
99    if let Data::Enum(ref data) = definition.data {
100        return data.variants.len() == 0;
101    }
102    return false;
103}
104
105fn create_impl(definition: &DeriveInput, kind_ident: &Path) -> TokenStream {
106    let (_, ty_generics, where_clause) = definition.generics.split_for_impl();
107    let ident = &definition.ident;
108
109    let arms = match &definition.data {
110        &Data::Enum(DataEnum { ref variants, .. }) => variants.iter().map(|ref v| {
111            let variant = &v.ident;
112            match v.fields {
113                Fields::Unit => quote! {
114                    &#ident::#variant => #kind_ident::#variant,
115                },
116                Fields::Unnamed(_) => quote! {
117                    &#ident::#variant(..) => #kind_ident::#variant,
118                },
119                Fields::Named(_) => quote! {
120                    &#ident::#variant{..} => #kind_ident::#variant,
121                },
122            }
123        }),
124        _ => {
125            panic!("#[derive(EnumKind)] is only allowed for enums");
126        }
127    };
128
129    let trait_: Path = if cfg!(feature = "no-stdlib") {
130        parse_quote!(::core::convert::From)
131    } else {
132        parse_quote!(::std::convert::From)
133    };
134
135    let mut counter: u32 = 1;
136    let used: HashSet<Lifetime> = definition
137        .generics
138        .lifetimes()
139        .map(|ld| ld.lifetime.clone())
140        .collect();
141    let a = loop {
142        let lifetime: Lifetime = syn::parse_str(&format!("'__enum_kinds{}", counter)).unwrap();
143        if !used.contains(&lifetime) {
144            break LifetimeDef::new(lifetime);
145        }
146        counter += 1;
147    };
148
149    let mut generics = definition.generics.clone();
150    generics.params.insert(0, GenericParam::Lifetime(a.clone()));
151    let (impl_generics, _, _) = generics.split_for_impl();
152
153    let impl_ = if is_uninhabited_enum(definition) {
154        quote! {
155            unreachable!();
156        }
157    } else {
158        quote! {
159            match _value {
160                #(#arms)*
161            }
162        }
163    };
164
165    let tokens = quote! {
166        #[automatically_derived]
167        #[allow(unused_attributes)]
168        impl #impl_generics #trait_<&#a #ident#ty_generics> for #kind_ident #where_clause {
169            fn from(_value: &#a #ident#ty_generics) -> Self {
170                #impl_
171            }
172        }
173
174        #[automatically_derived]
175        #[allow(unused_attributes)]
176        impl #impl_generics #trait_<#ident#ty_generics> for #kind_ident #where_clause {
177            fn from(value: #ident#ty_generics) -> Self {
178                #kind_ident::from(&value)
179            }
180        }
181    };
182    TokenStream::from(tokens)
183}