enum_iterator_derive/
lib.rs

1// Copyright (c) 2018-2022 Stephane Raux. Distributed under the 0BSD license.
2
3//! # Overview
4//! - [📦 crates.io](https://crates.io/crates/enum-iterator-derive)
5//! - [📖 Documentation](https://docs.rs/enum-iterator-derive)
6//! - [⚖ 0BSD license](https://spdx.org/licenses/0BSD.html)
7//!
8//! Procedural macro to derive `Sequence`.
9//!
10//! See crate [`enum-iterator`](https://docs.rs/enum-iterator) for details.
11//!
12//! # Contribute
13//! All contributions shall be licensed under the [0BSD license](https://spdx.org/licenses/0BSD.html).
14
15#![recursion_limit = "128"]
16#![deny(warnings)]
17
18extern crate proc_macro;
19
20use proc_macro2::{Span, TokenStream};
21use quote::{quote, ToTokens};
22use std::{
23    collections::HashMap,
24    fmt::{self, Display},
25    iter::{self, once, repeat},
26};
27use syn::{
28    punctuated::Punctuated, token::Comma, DeriveInput, Field, Fields, Generics, Ident, Member,
29    Path, PathSegment, PredicateType, TraitBound, TraitBoundModifier, Type, TypeParamBound,
30    Variant, WhereClause, WherePredicate,
31};
32
33/// Derives `Sequence`.
34#[proc_macro_derive(Sequence)]
35pub fn derive_sequence(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36    derive(input)
37        .unwrap_or_else(|e| e.to_compile_error())
38        .into()
39}
40
41fn derive(input: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
42    derive_for_ast(syn::parse(input)?)
43}
44
45fn derive_for_ast(ast: DeriveInput) -> Result<TokenStream, syn::Error> {
46    let ty = &ast.ident;
47    let generics = &ast.generics;
48    match &ast.data {
49        syn::Data::Struct(s) => derive_for_struct(ty, generics, &s.fields),
50        syn::Data::Enum(e) => derive_for_enum(ty, generics, &e.variants),
51        syn::Data::Union(_) => Err(Error::UnsupportedUnion.with_tokens(&ast)),
52    }
53}
54
55fn derive_for_struct(
56    ty: &Ident,
57    generics: &Generics,
58    fields: &Fields,
59) -> Result<TokenStream, syn::Error> {
60    let cardinality = tuple_cardinality(fields);
61    let first = init_value(ty, None, fields, Direction::Forward);
62    let last = init_value(ty, None, fields, Direction::Backward);
63    let next_body = advance_struct(ty, fields, Direction::Forward);
64    let previous_body = advance_struct(ty, fields, Direction::Backward);
65    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
66    let where_clause = if generics.params.is_empty() {
67        where_clause.cloned()
68    } else {
69        let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
70            where_token: Default::default(),
71            predicates: Default::default(),
72        });
73        clause.predicates.extend(
74            trait_bounds(group_type_requirements(
75                fields.iter().rev().zip(tuple_type_requirements()),
76            ))
77            .map(WherePredicate::Type),
78        );
79        Some(clause)
80    };
81    let tokens = quote! {
82        impl #impl_generics ::enum_iterator::Sequence for #ty #ty_generics #where_clause {
83            #[allow(clippy::identity_op)]
84            const CARDINALITY: usize = #cardinality;
85
86            fn next(&self) -> ::core::option::Option<Self> {
87                #next_body
88            }
89
90            fn previous(&self) -> ::core::option::Option<Self> {
91                #previous_body
92            }
93
94            fn first() -> ::core::option::Option<Self> {
95                #first
96            }
97
98            fn last() -> ::core::option::Option<Self> {
99                #last
100            }
101        }
102    };
103    Ok(tokens)
104}
105
106fn derive_for_enum(
107    ty: &Ident,
108    generics: &Generics,
109    variants: &Punctuated<Variant, Comma>,
110) -> Result<TokenStream, syn::Error> {
111    let cardinality = enum_cardinality(variants);
112    let next_body = advance_enum(ty, variants, Direction::Forward);
113    let previous_body = advance_enum(ty, variants, Direction::Backward);
114    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
115    let where_clause = if generics.params.is_empty() {
116        where_clause.cloned()
117    } else {
118        let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
119            where_token: Default::default(),
120            predicates: Default::default(),
121        });
122        clause.predicates.extend(
123            trait_bounds(group_type_requirements(variants.iter().flat_map(
124                |variant| variant.fields.iter().rev().zip(tuple_type_requirements()),
125            )))
126            .map(WherePredicate::Type),
127        );
128        Some(clause)
129    };
130    let next_variant_body = next_variant(ty, variants, Direction::Forward);
131    let previous_variant_body = next_variant(ty, variants, Direction::Backward);
132    let (first, last) = if variants.is_empty() {
133        (
134            quote! { ::core::option::Option::None },
135            quote! { ::core::option::Option::None },
136        )
137    } else {
138        let last_index = variants.len() - 1;
139        (
140            quote! { next_variant(0) },
141            quote! { previous_variant(#last_index) },
142        )
143    };
144    let tokens = quote! {
145        impl #impl_generics ::enum_iterator::Sequence for #ty #ty_generics #where_clause {
146            #[allow(clippy::identity_op)]
147            const CARDINALITY: usize = #cardinality;
148
149            fn next(&self) -> ::core::option::Option<Self> {
150                #next_body
151            }
152
153            fn previous(&self) -> ::core::option::Option<Self> {
154                #previous_body
155            }
156
157            fn first() -> ::core::option::Option<Self> {
158                #first
159            }
160
161            fn last() -> ::core::option::Option<Self> {
162                #last
163            }
164        }
165
166        fn next_variant #impl_generics(
167            mut i: usize,
168        ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
169            #next_variant_body
170        }
171
172        fn previous_variant #impl_generics(
173            mut i: usize,
174        ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
175            #previous_variant_body
176        }
177    };
178    let tokens = quote! {
179        const _: () = { #tokens };
180    };
181    Ok(tokens)
182}
183
184fn enum_cardinality(variants: &Punctuated<Variant, Comma>) -> TokenStream {
185    let terms = variants
186        .iter()
187        .map(|variant| tuple_cardinality(&variant.fields));
188    quote! {
189        #((#terms) +)* 0
190    }
191}
192
193fn tuple_cardinality(fields: &Fields) -> TokenStream {
194    let factors = fields.iter().map(|field| {
195        let ty = &field.ty;
196        quote! {
197            <#ty as ::enum_iterator::Sequence>::CARDINALITY
198        }
199    });
200    quote! {
201        #(#factors *)* 1
202    }
203}
204
205fn field_id(field: &Field, index: usize) -> Member {
206    field
207        .ident
208        .clone()
209        .map_or_else(|| Member::from(index), Member::from)
210}
211
212fn init_value(
213    ty: &Ident,
214    variant: Option<&Ident>,
215    fields: &Fields,
216    direction: Direction,
217) -> TokenStream {
218    let id = variant.map_or_else(|| quote! { #ty }, |v| quote! { #ty::#v });
219    if fields.is_empty() {
220        quote! {
221            ::core::option::Option::Some(#id {})
222        }
223    } else {
224        let reset = direction.reset();
225        let initialization =
226            repeat(quote! { ::enum_iterator::Sequence::#reset() }).take(fields.len());
227        let assignments = field_assignments(fields);
228        let bindings = bindings().take(fields.len());
229        quote! {{
230            match (#(#initialization,)*) {
231                (#(::core::option::Option::Some(#bindings),)*) => {
232                    ::core::option::Option::Some(#id { #assignments })
233                }
234                _ => ::core::option::Option::None,
235            }
236        }}
237    }
238}
239
240fn next_variant(
241    ty: &Ident,
242    variants: &Punctuated<Variant, Comma>,
243    direction: Direction,
244) -> TokenStream {
245    let advance = match direction {
246        Direction::Forward => {
247            let last_index = variants.len().saturating_sub(1);
248            quote! {
249                if i >= #last_index { break ::core::option::Option::None; } else { i+= 1; }
250            }
251        }
252        Direction::Backward => quote! {
253            if i == 0 { break ::core::option::Option::None; } else { i -= 1; }
254        },
255    };
256    let arms = variants.iter().enumerate().map(|(i, v)| {
257        let id = &v.ident;
258        let init = init_value(ty, Some(id), &v.fields, direction);
259        quote! {
260            #i => #init
261        }
262    });
263    quote! {
264        loop {
265            let next = match i {
266                #(#arms,)*
267                _ => ::core::option::Option::None,
268            };
269            match next {
270                ::core::option::Option::Some(_) => break next,
271                ::core::option::Option::None => #advance,
272            }
273        }
274    }
275}
276
277fn advance_struct(ty: &Ident, fields: &Fields, direction: Direction) -> TokenStream {
278    let assignments = field_assignments(fields);
279    let bindings = bindings().take(fields.len()).collect::<Vec<_>>();
280    let tuple = advance_tuple(&bindings, direction);
281    quote! {
282        let #ty { #assignments } = self;
283        let (#(#bindings,)*) = #tuple?;
284        ::core::option::Option::Some(#ty { #assignments })
285    }
286}
287
288fn advance_enum(
289    ty: &Ident,
290    variants: &Punctuated<Variant, Comma>,
291    direction: Direction,
292) -> TokenStream {
293    let arms: Vec<_> = match direction {
294        Direction::Forward => variants
295            .iter()
296            .enumerate()
297            .map(|(i, variant)| advance_enum_arm(ty, direction, i, variant))
298            .collect(),
299        Direction::Backward => variants
300            .iter()
301            .enumerate()
302            .rev()
303            .map(|(i, variant)| advance_enum_arm(ty, direction, i, variant))
304            .collect(),
305    };
306    quote! {
307        match *self {
308            #(#arms,)*
309        }
310    }
311}
312
313fn advance_enum_arm(ty: &Ident, direction: Direction, i: usize, variant: &Variant) -> TokenStream {
314    let next = match direction {
315        Direction::Forward => match i.checked_add(1) {
316            Some(next_i) => quote! { next_variant(#next_i) },
317            None => quote! { ::core::option::Option::None },
318        },
319        Direction::Backward => match i.checked_sub(1) {
320            Some(prev_i) => quote! { previous_variant(#prev_i) },
321            None => quote! { ::core::option::Option::None },
322        },
323    };
324    let id = &variant.ident;
325    if variant.fields.is_empty() {
326        quote! {
327            #ty::#id {} => #next
328        }
329    } else {
330        let destructuring = field_bindings(&variant.fields);
331        let assignments = field_assignments(&variant.fields);
332        let bindings = bindings().take(variant.fields.len()).collect::<Vec<_>>();
333        let tuple = advance_tuple(&bindings, direction);
334        quote! {
335            #ty::#id { #destructuring } => {
336                let y = #tuple;
337                match y {
338                    ::core::option::Option::Some((#(#bindings,)*)) => {
339                        ::core::option::Option::Some(#ty::#id { #assignments })
340                    }
341                    ::core::option::Option::None => #next,
342                }
343            }
344        }
345    }
346}
347
348fn advance_tuple(bindings: &[Ident], direction: Direction) -> TokenStream {
349    let advance = direction.advance();
350    let reset = direction.reset();
351    let rev_bindings = bindings.iter().rev().collect::<Vec<_>>();
352    let (rev_binding_head, rev_binding_tail) = match rev_bindings.split_first() {
353        Some((&head, tail)) => (Some(head), tail),
354        None => (None, &*rev_bindings),
355    };
356    let rev_binding_head = match rev_binding_head {
357        Some(head) => quote! {
358            let (#head, carry) = match ::enum_iterator::Sequence::#advance(#head) {
359                ::core::option::Option::Some(#head) => (::core::option::Option::Some(#head), false),
360                ::core::option::Option::None => (::enum_iterator::Sequence::#reset(), true),
361            };
362        },
363        None => quote! {
364            let carry = true;
365        },
366    };
367    let body = quote! {
368        #rev_binding_head
369        #(
370            let (#rev_binding_tail, carry) = if carry {
371                match ::enum_iterator::Sequence::#advance(#rev_binding_tail) {
372                    ::core::option::Option::Some(#rev_binding_tail) => {
373                        (::core::option::Option::Some(#rev_binding_tail), false)
374                    }
375                    ::core::option::Option::None => (::enum_iterator::Sequence::#reset(), true),
376                }
377            } else {
378                (
379                    ::core::option::Option::Some(::core::clone::Clone::clone(#rev_binding_tail)),
380                    false,
381                )
382            };
383        )*
384        if carry {
385            ::core::option::Option::None
386        } else {
387            match (#(#bindings,)*) {
388                (#(::core::option::Option::Some(#bindings),)*) => {
389                    ::core::option::Option::Some((#(#bindings,)*))
390                }
391                _ => ::core::option::Option::None,
392            }
393        }
394    };
395    quote! {
396        { #body }
397    }
398}
399
400fn field_assignments<'a, I>(fields: I) -> TokenStream
401where
402    I: IntoIterator<Item = &'a Field>,
403{
404    fields
405        .into_iter()
406        .enumerate()
407        .zip(bindings())
408        .map(|((i, field), binding)| {
409            let field_id = field_id(field, i);
410            quote! { #field_id: #binding, }
411        })
412        .collect()
413}
414
415fn field_bindings<'a, I>(fields: I) -> TokenStream
416where
417    I: IntoIterator<Item = &'a Field>,
418{
419    fields
420        .into_iter()
421        .enumerate()
422        .zip(bindings())
423        .map(|((i, field), binding)| {
424            let field_id = field_id(field, i);
425            quote! { #field_id: ref #binding, }
426        })
427        .collect()
428}
429
430fn bindings() -> impl Iterator<Item = Ident> {
431    (0..).map(|i| Ident::new(&format!("x{i}"), Span::call_site()))
432}
433
434fn trait_bounds<I>(types: I) -> impl Iterator<Item = PredicateType>
435where
436    I: IntoIterator<Item = (Type, TypeRequirements)>,
437{
438    types
439        .into_iter()
440        .map(|(bounded_ty, requirements)| PredicateType {
441            lifetimes: None,
442            bounded_ty,
443            colon_token: Default::default(),
444            bounds: requirements
445                .into_iter()
446                .map(|req| match req {
447                    TypeRequirement::Clone => clone_trait_path(),
448                    TypeRequirement::Sequence => trait_path(),
449                })
450                .map(trait_bound)
451                .collect(),
452        })
453}
454
455fn trait_bound(path: Path) -> TypeParamBound {
456    TypeParamBound::Trait(TraitBound {
457        paren_token: None,
458        modifier: TraitBoundModifier::None,
459        lifetimes: None,
460        path,
461    })
462}
463
464fn trait_path() -> Path {
465    Path {
466        leading_colon: Some(Default::default()),
467        segments: [
468            PathSegment::from(Ident::new("enum_iterator", Span::call_site())),
469            Ident::new("Sequence", Span::call_site()).into(),
470        ]
471        .into_iter()
472        .collect(),
473    }
474}
475
476fn clone_trait_path() -> Path {
477    Path {
478        leading_colon: Some(Default::default()),
479        segments: [
480            PathSegment::from(Ident::new("core", Span::call_site())),
481            Ident::new("clone", Span::call_site()).into(),
482            Ident::new("Clone", Span::call_site()).into(),
483        ]
484        .into_iter()
485        .collect(),
486    }
487}
488
489fn tuple_type_requirements() -> impl Iterator<Item = TypeRequirements> {
490    once([TypeRequirement::Sequence].into()).chain(repeat(
491        [TypeRequirement::Sequence, TypeRequirement::Clone].into(),
492    ))
493}
494
495fn group_type_requirements<'a, I>(bounds: I) -> Vec<(Type, TypeRequirements)>
496where
497    I: IntoIterator<Item = (&'a Field, TypeRequirements)>,
498{
499    bounds
500        .into_iter()
501        .fold(
502            (HashMap::<_, usize>::new(), Vec::new()),
503            |(mut indexes, mut acc), (field, requirements)| {
504                let i = *indexes.entry(field.ty.clone()).or_insert_with(|| {
505                    acc.push((field.ty.clone(), TypeRequirements::new()));
506                    acc.len() - 1
507                });
508                acc[i].1.extend(requirements);
509                (indexes, acc)
510            },
511        )
512        .1
513}
514
/// A trait a field type must implement for the generated `Sequence` impl to compile.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total. This
/// is consistent with `Direction`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TypeRequirement {
    /// The type must implement `::enum_iterator::Sequence`.
    Sequence,
    /// The type must implement `::core::clone::Clone`.
    Clone,
}
520
/// Set of `TypeRequirement`s collected for one field type, stored as a bit set in a `u8`.
///
/// `Eq` is derived alongside `PartialEq` because equality on the underlying `u8` is total.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct TypeRequirements(u8);
523
524impl TypeRequirements {
525    const SEQUENCE: u8 = 0x1;
526    const CLONE: u8 = 0x2;
527
528    fn new() -> Self {
529        Self::default()
530    }
531
532    fn insert(&mut self, req: TypeRequirement) {
533        self.0 |= Self::enum_to_mask(req);
534    }
535
536    fn into_iter(self) -> impl Iterator<Item = TypeRequirement> {
537        let mut n = self.0;
538        iter::from_fn(move || {
539            if n & Self::SEQUENCE != 0 {
540                n &= !Self::SEQUENCE;
541                Some(TypeRequirement::Sequence)
542            } else if n & Self::CLONE != 0 {
543                n &= !Self::CLONE;
544                Some(TypeRequirement::Clone)
545            } else {
546                None
547            }
548        })
549    }
550
551    fn extend(&mut self, other: Self) {
552        self.0 |= other.0;
553    }
554
555    fn enum_to_mask(req: TypeRequirement) -> u8 {
556        match req {
557            TypeRequirement::Sequence => Self::SEQUENCE,
558            TypeRequirement::Clone => Self::CLONE,
559        }
560    }
561}
562
563impl<const N: usize> From<[TypeRequirement; N]> for TypeRequirements {
564    fn from(reqs: [TypeRequirement; N]) -> Self {
565        reqs.into_iter()
566            .fold(TypeRequirements::new(), |mut acc, req| {
567                acc.insert(req);
568                acc
569            })
570    }
571}
572
/// Direction in which generated code walks a sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Direction {
    /// Toward later values (`next`, starting from `first`).
    Forward,
    /// Toward earlier values (`previous`, starting from `last`).
    Backward,
}
578
579impl Direction {
580    fn advance(self) -> Ident {
581        let s = match self {
582            Direction::Forward => "next",
583            Direction::Backward => "previous",
584        };
585        Ident::new(s, Span::call_site())
586    }
587
588    fn reset(self) -> Ident {
589        let s = match self {
590            Direction::Forward => "first",
591            Direction::Backward => "last",
592        };
593        Ident::new(s, Span::call_site())
594    }
595}
596
/// Errors this derive macro reports to the user.
#[derive(Debug)]
enum Error {
    /// `#[derive(Sequence)]` was applied to a `union`.
    UnsupportedUnion,
}
601
602impl Error {
603    fn with_tokens<T: ToTokens>(self, tokens: T) -> syn::Error {
604        syn::Error::new_spanned(tokens, self)
605    }
606}
607
608impl Display for Error {
609    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
610        match self {
611            Error::UnsupportedUnion => f.write_str("Sequence cannot be derived for union types"),
612        }
613    }
614}