enum_iterator_derive/
lib.rs

// Copyright (c) 2018-2022 Stephane Raux. Distributed under the 0BSD license.

//! # Overview
//! - [📦 crates.io](https://crates.io/crates/enum-iterator-derive)
//! - [📖 Documentation](https://docs.rs/enum-iterator-derive)
//! - [⚖ 0BSD license](https://spdx.org/licenses/0BSD.html)
//!
//! Procedural macro to derive `Sequence`.
//!
//! See crate [`enum-iterator`](https://docs.rs/enum-iterator) for details.
//!
//! # Contribute
//! All contributions shall be licensed under the [0BSD license](https://spdx.org/licenses/0BSD.html).
14
15#![recursion_limit = "128"]
16#![deny(warnings)]
17
18extern crate proc_macro;
19
20use proc_macro2::{Span, TokenStream};
21use quote::{quote, ToTokens};
22use std::{
23    collections::HashMap,
24    fmt::{self, Display},
25    iter::{self, once, repeat, repeat_n},
26};
27use syn::{
28    punctuated::Punctuated, token::Comma, DeriveInput, Field, Fields, Generics, Ident, Member,
29    Path, PathSegment, PredicateType, TraitBound, TraitBoundModifier, Type, TypeParamBound,
30    Variant, WhereClause, WherePredicate,
31};
32
33/// Derives `Sequence`.
34#[proc_macro_derive(Sequence, attributes(enum_iterator))]
35pub fn derive_sequence(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36    derive(input)
37        .unwrap_or_else(|e| e.to_compile_error())
38        .into()
39}
40
41fn derive(input: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
42    derive_for_ast(syn::parse(input)?)
43}
44
45#[derive(Debug)]
46struct DeriveOptions {
47    crate_path: Path,
48}
49
50impl DeriveOptions {
51    fn parse(attrs: &[syn::Attribute]) -> Result<Self, syn::Error> {
52        let mut crate_path = None;
53        attrs
54            .iter()
55            .filter(|attr| attr.path().is_ident("enum_iterator"))
56            .try_for_each(|attr| {
57                attr.parse_nested_meta(|meta| {
58                    if meta.path.is_ident("crate") {
59                        let path: Path = meta.value()?.parse()?;
60                        if crate_path.is_none() {
61                            crate_path = Some(path);
62                            Ok(())
63                        } else {
64                            Err(meta.error("duplicate crate key"))
65                        }
66                    } else {
67                        Err(meta.error(format!("unknown key {}", meta.path.to_token_stream())))
68                    }
69                })
70            })?;
71        Ok(Self {
72            crate_path: crate_path.unwrap_or_else(|| Path {
73                leading_colon: Some(Default::default()),
74                segments: [PathSegment::from(Ident::new(
75                    "enum_iterator",
76                    Span::call_site(),
77                ))]
78                .into_iter()
79                .collect(),
80            }),
81        })
82    }
83}
84
85fn derive_for_ast(ast: DeriveInput) -> Result<TokenStream, syn::Error> {
86    let ty = &ast.ident;
87    let generics = &ast.generics;
88    let options = DeriveOptions::parse(&ast.attrs)?;
89    match &ast.data {
90        syn::Data::Struct(s) => derive_for_struct(&options, ty, generics, &s.fields),
91        syn::Data::Enum(e) => derive_for_enum(&options, ty, generics, &e.variants),
92        syn::Data::Union(_) => Err(Error::UnsupportedUnion.with_tokens(&ast)),
93    }
94}
95
96fn derive_for_struct(
97    options: &DeriveOptions,
98    ty: &Ident,
99    generics: &Generics,
100    fields: &Fields,
101) -> Result<TokenStream, syn::Error> {
102    let crate_path = &options.crate_path;
103    let cardinality = tuple_cardinality(&options.crate_path, fields);
104    let first = init_value(&options.crate_path, ty, None, fields, Direction::Forward);
105    let last = init_value(&options.crate_path, ty, None, fields, Direction::Backward);
106    let next_body = advance_struct(&options.crate_path, ty, fields, Direction::Forward);
107    let previous_body = advance_struct(&options.crate_path, ty, fields, Direction::Backward);
108    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
109    let where_clause = if generics.params.is_empty() {
110        where_clause.cloned()
111    } else {
112        let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
113            where_token: Default::default(),
114            predicates: Default::default(),
115        });
116        clause.predicates.extend(
117            trait_bounds(
118                &options.crate_path,
119                group_type_requirements(fields.iter().rev().zip(tuple_type_requirements())),
120            )
121            .map(WherePredicate::Type),
122        );
123        Some(clause)
124    };
125    let tokens = quote! {
126        impl #impl_generics #crate_path::Sequence for #ty #ty_generics #where_clause {
127            #[allow(clippy::identity_op)]
128            const CARDINALITY: usize = #cardinality;
129
130            fn next(&self) -> ::core::option::Option<Self> {
131                #next_body
132            }
133
134            fn previous(&self) -> ::core::option::Option<Self> {
135                #previous_body
136            }
137
138            fn first() -> ::core::option::Option<Self> {
139                #first
140            }
141
142            fn last() -> ::core::option::Option<Self> {
143                #last
144            }
145        }
146    };
147    Ok(tokens)
148}
149
150fn derive_for_enum(
151    options: &DeriveOptions,
152    ty: &Ident,
153    generics: &Generics,
154    variants: &Punctuated<Variant, Comma>,
155) -> Result<TokenStream, syn::Error> {
156    let cardinality = enum_cardinality(&options.crate_path, variants);
157    let next_body = advance_enum(&options.crate_path, ty, variants, Direction::Forward);
158    let previous_body = advance_enum(&options.crate_path, ty, variants, Direction::Backward);
159    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
160    let where_clause = if generics.params.is_empty() {
161        where_clause.cloned()
162    } else {
163        let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
164            where_token: Default::default(),
165            predicates: Default::default(),
166        });
167        clause.predicates.extend(
168            trait_bounds(
169                &options.crate_path,
170                group_type_requirements(variants.iter().flat_map(|variant| {
171                    variant.fields.iter().rev().zip(tuple_type_requirements())
172                })),
173            )
174            .map(WherePredicate::Type),
175        );
176        Some(clause)
177    };
178    let next_variant_body = next_variant(&options.crate_path, ty, variants, Direction::Forward);
179    let previous_variant_body =
180        next_variant(&options.crate_path, ty, variants, Direction::Backward);
181    let (first, last) = if variants.is_empty() {
182        (
183            quote! { ::core::option::Option::None },
184            quote! { ::core::option::Option::None },
185        )
186    } else {
187        let last_index = variants.len() - 1;
188        (
189            quote! { next_variant(0) },
190            quote! { previous_variant(#last_index) },
191        )
192    };
193    let crate_path = &options.crate_path;
194    let tokens = quote! {
195        impl #impl_generics #crate_path::Sequence for #ty #ty_generics #where_clause {
196            #[allow(clippy::identity_op)]
197            const CARDINALITY: usize = #cardinality;
198
199            fn next(&self) -> ::core::option::Option<Self> {
200                #next_body
201            }
202
203            fn previous(&self) -> ::core::option::Option<Self> {
204                #previous_body
205            }
206
207            fn first() -> ::core::option::Option<Self> {
208                #first
209            }
210
211            fn last() -> ::core::option::Option<Self> {
212                #last
213            }
214        }
215
216        fn next_variant #impl_generics(
217            mut i: usize,
218        ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
219            #next_variant_body
220        }
221
222        fn previous_variant #impl_generics(
223            mut i: usize,
224        ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
225            #previous_variant_body
226        }
227    };
228    let tokens = quote! {
229        const _: () = { #tokens };
230    };
231    Ok(tokens)
232}
233
234fn enum_cardinality(crate_path: &Path, variants: &Punctuated<Variant, Comma>) -> TokenStream {
235    let terms = variants
236        .iter()
237        .map(|variant| tuple_cardinality(crate_path, &variant.fields));
238    quote! {
239        #((#terms) +)* 0
240    }
241}
242
243fn tuple_cardinality(crate_path: &Path, fields: &Fields) -> TokenStream {
244    let factors = fields.iter().map(|field| {
245        let ty = &field.ty;
246        quote! {
247            <#ty as #crate_path::Sequence>::CARDINALITY
248        }
249    });
250    quote! {
251        #(#factors *)* 1
252    }
253}
254
255fn field_id(field: &Field, index: usize) -> Member {
256    field
257        .ident
258        .clone()
259        .map_or_else(|| Member::from(index), Member::from)
260}
261
262fn init_value(
263    crate_path: &Path,
264    ty: &Ident,
265    variant: Option<&Ident>,
266    fields: &Fields,
267    direction: Direction,
268) -> TokenStream {
269    let id = variant.map_or_else(|| quote! { #ty }, |v| quote! { #ty::#v });
270    if fields.is_empty() {
271        quote! {
272            ::core::option::Option::Some(#id {})
273        }
274    } else {
275        let reset = direction.reset();
276        let initialization = repeat_n(quote! { #crate_path::Sequence::#reset() }, fields.len());
277        let assignments = field_assignments(fields);
278        let bindings = bindings().take(fields.len());
279        quote! {{
280            match (#(#initialization,)*) {
281                (#(::core::option::Option::Some(#bindings),)*) => {
282                    ::core::option::Option::Some(#id { #assignments })
283                }
284                _ => ::core::option::Option::None,
285            }
286        }}
287    }
288}
289
290fn next_variant(
291    crate_path: &Path,
292    ty: &Ident,
293    variants: &Punctuated<Variant, Comma>,
294    direction: Direction,
295) -> TokenStream {
296    let advance = match direction {
297        Direction::Forward => {
298            let last_index = variants.len().saturating_sub(1);
299            quote! {
300                if i >= #last_index { break ::core::option::Option::None; } else { i+= 1; }
301            }
302        }
303        Direction::Backward => quote! {
304            if i == 0 { break ::core::option::Option::None; } else { i -= 1; }
305        },
306    };
307    let arms = variants.iter().enumerate().map(|(i, v)| {
308        let id = &v.ident;
309        let init = init_value(crate_path, ty, Some(id), &v.fields, direction);
310        quote! {
311            #i => #init
312        }
313    });
314    quote! {
315        loop {
316            let next = match i {
317                #(#arms,)*
318                _ => ::core::option::Option::None,
319            };
320            match next {
321                ::core::option::Option::Some(_) => break next,
322                ::core::option::Option::None => #advance,
323            }
324        }
325    }
326}
327
328fn advance_struct(
329    crate_path: &Path,
330    ty: &Ident,
331    fields: &Fields,
332    direction: Direction,
333) -> TokenStream {
334    let assignments = field_assignments(fields);
335    let bindings = bindings().take(fields.len()).collect::<Vec<_>>();
336    let tuple = advance_tuple(crate_path, &bindings, direction);
337    quote! {
338        let #ty { #assignments } = self;
339        let (#(#bindings,)*) = #tuple?;
340        ::core::option::Option::Some(#ty { #assignments })
341    }
342}
343
344fn advance_enum(
345    crate_path: &Path,
346    ty: &Ident,
347    variants: &Punctuated<Variant, Comma>,
348    direction: Direction,
349) -> TokenStream {
350    let arms: Vec<_> = match direction {
351        Direction::Forward => variants
352            .iter()
353            .enumerate()
354            .map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant))
355            .collect(),
356        Direction::Backward => variants
357            .iter()
358            .enumerate()
359            .rev()
360            .map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant))
361            .collect(),
362    };
363    quote! {
364        match *self {
365            #(#arms,)*
366        }
367    }
368}
369
370fn advance_enum_arm(
371    crate_path: &Path,
372    ty: &Ident,
373    direction: Direction,
374    i: usize,
375    variant: &Variant,
376) -> TokenStream {
377    let next = match direction {
378        Direction::Forward => match i.checked_add(1) {
379            Some(next_i) => quote! { next_variant(#next_i) },
380            None => quote! { ::core::option::Option::None },
381        },
382        Direction::Backward => match i.checked_sub(1) {
383            Some(prev_i) => quote! { previous_variant(#prev_i) },
384            None => quote! { ::core::option::Option::None },
385        },
386    };
387    let id = &variant.ident;
388    if variant.fields.is_empty() {
389        quote! {
390            #ty::#id {} => #next
391        }
392    } else {
393        let destructuring = field_bindings(&variant.fields);
394        let assignments = field_assignments(&variant.fields);
395        let bindings = bindings().take(variant.fields.len()).collect::<Vec<_>>();
396        let tuple = advance_tuple(crate_path, &bindings, direction);
397        quote! {
398            #ty::#id { #destructuring } => {
399                let y = #tuple;
400                match y {
401                    ::core::option::Option::Some((#(#bindings,)*)) => {
402                        ::core::option::Option::Some(#ty::#id { #assignments })
403                    }
404                    ::core::option::Option::None => #next,
405                }
406            }
407        }
408    }
409}
410
411fn advance_tuple(crate_path: &Path, bindings: &[Ident], direction: Direction) -> TokenStream {
412    let advance = direction.advance();
413    let reset = direction.reset();
414    let rev_bindings = bindings.iter().rev().collect::<Vec<_>>();
415    let (rev_binding_head, rev_binding_tail) = match rev_bindings.split_first() {
416        Some((&head, tail)) => (Some(head), tail),
417        None => (None, &*rev_bindings),
418    };
419    let rev_binding_head = match rev_binding_head {
420        Some(head) => quote! {
421            let (#head, carry) = match #crate_path::Sequence::#advance(#head) {
422                ::core::option::Option::Some(#head) => (::core::option::Option::Some(#head), false),
423                ::core::option::Option::None => (#crate_path::Sequence::#reset(), true),
424            };
425        },
426        None => quote! {
427            let carry = true;
428        },
429    };
430    let body = quote! {
431        #rev_binding_head
432        #(
433            let (#rev_binding_tail, carry) = if carry {
434                match #crate_path::Sequence::#advance(#rev_binding_tail) {
435                    ::core::option::Option::Some(#rev_binding_tail) => {
436                        (::core::option::Option::Some(#rev_binding_tail), false)
437                    }
438                    ::core::option::Option::None => (#crate_path::Sequence::#reset(), true),
439                }
440            } else {
441                (
442                    ::core::option::Option::Some(::core::clone::Clone::clone(#rev_binding_tail)),
443                    false,
444                )
445            };
446        )*
447        if carry {
448            ::core::option::Option::None
449        } else {
450            match (#(#bindings,)*) {
451                (#(::core::option::Option::Some(#bindings),)*) => {
452                    ::core::option::Option::Some((#(#bindings,)*))
453                }
454                _ => ::core::option::Option::None,
455            }
456        }
457    };
458    quote! {
459        { #body }
460    }
461}
462
463fn field_assignments<'a, I>(fields: I) -> TokenStream
464where
465    I: IntoIterator<Item = &'a Field>,
466{
467    fields
468        .into_iter()
469        .enumerate()
470        .zip(bindings())
471        .map(|((i, field), binding)| {
472            let field_id = field_id(field, i);
473            quote! { #field_id: #binding, }
474        })
475        .collect()
476}
477
478fn field_bindings<'a, I>(fields: I) -> TokenStream
479where
480    I: IntoIterator<Item = &'a Field>,
481{
482    fields
483        .into_iter()
484        .enumerate()
485        .zip(bindings())
486        .map(|((i, field), binding)| {
487            let field_id = field_id(field, i);
488            quote! { #field_id: ref #binding, }
489        })
490        .collect()
491}
492
493fn bindings() -> impl Iterator<Item = Ident> {
494    (0..).map(|i| Ident::new(&format!("x{i}"), Span::call_site()))
495}
496
497fn trait_bounds<I>(crate_path: &Path, types: I) -> impl Iterator<Item = PredicateType>
498where
499    I: IntoIterator<Item = (Type, TypeRequirements)>,
500{
501    let crate_path = crate_path.clone();
502    types
503        .into_iter()
504        .map(move |(bounded_ty, requirements)| PredicateType {
505            lifetimes: None,
506            bounded_ty,
507            colon_token: Default::default(),
508            bounds: requirements
509                .into_iter()
510                .map(|req| match req {
511                    TypeRequirement::Clone => clone_trait_path(),
512                    TypeRequirement::Sequence => trait_path(&crate_path),
513                })
514                .map(trait_bound)
515                .collect(),
516        })
517}
518
519fn trait_bound(path: Path) -> TypeParamBound {
520    TypeParamBound::Trait(TraitBound {
521        paren_token: None,
522        modifier: TraitBoundModifier::None,
523        lifetimes: None,
524        path,
525    })
526}
527
528fn trait_path(crate_path: &Path) -> Path {
529    let mut path = crate_path.clone();
530    path.segments
531        .push(Ident::new("Sequence", Span::call_site()).into());
532    path
533}
534
535fn clone_trait_path() -> Path {
536    Path {
537        leading_colon: Some(Default::default()),
538        segments: [
539            PathSegment::from(Ident::new("core", Span::call_site())),
540            Ident::new("clone", Span::call_site()).into(),
541            Ident::new("Clone", Span::call_site()).into(),
542        ]
543        .into_iter()
544        .collect(),
545    }
546}
547
548fn tuple_type_requirements() -> impl Iterator<Item = TypeRequirements> {
549    once([TypeRequirement::Sequence].into()).chain(repeat(
550        [TypeRequirement::Sequence, TypeRequirement::Clone].into(),
551    ))
552}
553
554fn group_type_requirements<'a, I>(bounds: I) -> Vec<(Type, TypeRequirements)>
555where
556    I: IntoIterator<Item = (&'a Field, TypeRequirements)>,
557{
558    bounds
559        .into_iter()
560        .fold(
561            (HashMap::<_, usize>::new(), Vec::new()),
562            |(mut indexes, mut acc), (field, requirements)| {
563                let i = *indexes.entry(field.ty.clone()).or_insert_with(|| {
564                    acc.push((field.ty.clone(), TypeRequirements::new()));
565                    acc.len() - 1
566                });
567                acc[i].1.extend(requirements);
568                (indexes, acc)
569            },
570        )
571        .1
572}
573
/// A trait that a field type must implement for the generated code to compile.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TypeRequirement {
    /// The type must implement the crate's `Sequence` trait.
    Sequence,
    /// The type must implement `::core::clone::Clone`.
    Clone,
}

/// A set of [`TypeRequirement`]s stored as a bit mask, with one bit per requirement.
///
/// It is a plain `u8`, so it is `Copy` and `Eq`, and a set stays usable after iterating over it.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct TypeRequirements(u8);

impl TypeRequirements {
    /// Bit representing [`TypeRequirement::Sequence`].
    const SEQUENCE: u8 = 0x1;
    /// Bit representing [`TypeRequirement::Clone`].
    const CLONE: u8 = 0x2;

    /// Returns the empty set.
    fn new() -> Self {
        Self::default()
    }

    /// Adds `req` to the set. Adding a requirement that is already present has no effect.
    fn insert(&mut self, req: TypeRequirement) {
        self.0 |= Self::enum_to_mask(req);
    }

    /// Iterates over the requirements in the set, yielding `Sequence` before `Clone`.
    fn into_iter(self) -> impl Iterator<Item = TypeRequirement> {
        let mut remaining = self.0;
        iter::from_fn(move || {
            if remaining & Self::SEQUENCE != 0 {
                remaining &= !Self::SEQUENCE;
                Some(TypeRequirement::Sequence)
            } else if remaining & Self::CLONE != 0 {
                remaining &= !Self::CLONE;
                Some(TypeRequirement::Clone)
            } else {
                None
            }
        })
    }

    /// Adds every requirement of `other` to this set (set union).
    fn extend(&mut self, other: Self) {
        self.0 |= other.0;
    }

    /// Returns the bit representing `req`.
    fn enum_to_mask(req: TypeRequirement) -> u8 {
        match req {
            TypeRequirement::Sequence => Self::SEQUENCE,
            TypeRequirement::Clone => Self::CLONE,
        }
    }
}

impl<const N: usize> From<[TypeRequirement; N]> for TypeRequirements {
    /// Builds the set containing every requirement in `reqs`. Duplicates are ignored.
    fn from(reqs: [TypeRequirement; N]) -> Self {
        let mut set = TypeRequirements::new();
        for req in reqs {
            set.insert(req);
        }
        set
    }
}
631
632#[derive(Clone, Copy, Debug, Eq, PartialEq)]
633enum Direction {
634    Forward,
635    Backward,
636}
637
638impl Direction {
639    fn advance(self) -> Ident {
640        let s = match self {
641            Direction::Forward => "next",
642            Direction::Backward => "previous",
643        };
644        Ident::new(s, Span::call_site())
645    }
646
647    fn reset(self) -> Ident {
648        let s = match self {
649            Direction::Forward => "first",
650            Direction::Backward => "last",
651        };
652        Ident::new(s, Span::call_site())
653    }
654}
655
656#[derive(Debug)]
657enum Error {
658    UnsupportedUnion,
659}
660
661impl Error {
662    fn with_tokens<T: ToTokens>(self, tokens: T) -> syn::Error {
663        syn::Error::new_spanned(tokens, self)
664    }
665}
666
667impl Display for Error {
668    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
669        match self {
670            Error::UnsupportedUnion => f.write_str("Sequence cannot be derived for union types"),
671        }
672    }
673}
674
#[cfg(test)]
mod tests {
    use crate::DeriveOptions;
    use quote::quote;

    /// Parses `tokens` as a derive input and extracts its `enum_iterator` options.
    fn parse_options(tokens: proc_macro2::TokenStream) -> Result<DeriveOptions, syn::Error> {
        let input: syn::DeriveInput = syn::parse2(tokens).unwrap();
        DeriveOptions::parse(&input.attrs)
    }

    #[test]
    fn crate_path_can_be_parsed() {
        let options = parse_options(quote! {
            #[derive(Sequence)]
            #[enum_iterator(crate = foo::bar)]
            struct Foo;
        })
        .unwrap();
        let expected_path: syn::Path = syn::parse2(quote! { foo::bar }).unwrap();
        assert_eq!(options.crate_path, expected_path);
    }

    #[test]
    fn crate_path_defaults_to_enum_iterator() {
        let options = parse_options(quote! {
            #[derive(Sequence)]
            struct Foo;
        })
        .unwrap();
        let expected_path: syn::Path = syn::parse2(quote! { ::enum_iterator }).unwrap();
        assert_eq!(options.crate_path, expected_path);
    }

    #[test]
    fn unrelated_attributes_are_ignored() {
        let options = parse_options(quote! {
            #[derive(Sequence)]
            #[other(crate = not::used)]
            struct Foo;
        })
        .unwrap();
        let expected_path: syn::Path = syn::parse2(quote! { ::enum_iterator }).unwrap();
        assert_eq!(options.crate_path, expected_path);
    }

    #[test]
    fn unknown_key_is_rejected() {
        let error = parse_options(quote! {
            #[enum_iterator(krate = foo)]
            struct Foo;
        })
        .unwrap_err();
        assert_eq!(error.to_string(), "unknown key krate");
    }

    #[test]
    fn duplicate_crate_key_in_one_attribute_is_rejected() {
        let error = parse_options(quote! {
            #[enum_iterator(crate = foo, crate = bar)]
            struct Foo;
        })
        .unwrap_err();
        assert_eq!(error.to_string(), "duplicate crate key");
    }

    #[test]
    fn duplicate_crate_key_across_attributes_is_rejected() {
        let error = parse_options(quote! {
            #[enum_iterator(crate = foo)]
            #[enum_iterator(crate = bar)]
            struct Foo;
        })
        .unwrap_err();
        assert_eq!(error.to_string(), "duplicate crate key");
    }
}