1#![recursion_limit = "128"]
16#![deny(warnings)]
17
18extern crate proc_macro;
19
20use proc_macro2::{Span, TokenStream};
21use quote::{quote, ToTokens};
22use std::{
23 collections::HashMap,
24 fmt::{self, Display},
25 iter::{self, once, repeat, repeat_n},
26};
27use syn::{
28 punctuated::Punctuated, token::Comma, DeriveInput, Field, Fields, Generics, Ident, Member,
29 Path, PathSegment, PredicateType, TraitBound, TraitBoundModifier, Type, TypeParamBound,
30 Variant, WhereClause, WherePredicate,
31};
32
33#[proc_macro_derive(Sequence, attributes(enum_iterator))]
35pub fn derive_sequence(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36 derive(input)
37 .unwrap_or_else(|e| e.to_compile_error())
38 .into()
39}
40
41fn derive(input: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
42 derive_for_ast(syn::parse(input)?)
43}
44
45#[derive(Debug)]
46struct DeriveOptions {
47 crate_path: Path,
48}
49
50impl DeriveOptions {
51 fn parse(attrs: &[syn::Attribute]) -> Result<Self, syn::Error> {
52 let mut crate_path = None;
53 attrs
54 .iter()
55 .filter(|attr| attr.path().is_ident("enum_iterator"))
56 .try_for_each(|attr| {
57 attr.parse_nested_meta(|meta| {
58 if meta.path.is_ident("crate") {
59 let path: Path = meta.value()?.parse()?;
60 if crate_path.is_none() {
61 crate_path = Some(path);
62 Ok(())
63 } else {
64 Err(meta.error("duplicate crate key"))
65 }
66 } else {
67 Err(meta.error(format!("unknown key {}", meta.path.to_token_stream())))
68 }
69 })
70 })?;
71 Ok(Self {
72 crate_path: crate_path.unwrap_or_else(|| Path {
73 leading_colon: Some(Default::default()),
74 segments: [PathSegment::from(Ident::new(
75 "enum_iterator",
76 Span::call_site(),
77 ))]
78 .into_iter()
79 .collect(),
80 }),
81 })
82 }
83}
84
85fn derive_for_ast(ast: DeriveInput) -> Result<TokenStream, syn::Error> {
86 let ty = &ast.ident;
87 let generics = &ast.generics;
88 let options = DeriveOptions::parse(&ast.attrs)?;
89 match &ast.data {
90 syn::Data::Struct(s) => derive_for_struct(&options, ty, generics, &s.fields),
91 syn::Data::Enum(e) => derive_for_enum(&options, ty, generics, &e.variants),
92 syn::Data::Union(_) => Err(Error::UnsupportedUnion.with_tokens(&ast)),
93 }
94}
95
96fn derive_for_struct(
97 options: &DeriveOptions,
98 ty: &Ident,
99 generics: &Generics,
100 fields: &Fields,
101) -> Result<TokenStream, syn::Error> {
102 let crate_path = &options.crate_path;
103 let cardinality = tuple_cardinality(&options.crate_path, fields);
104 let first = init_value(&options.crate_path, ty, None, fields, Direction::Forward);
105 let last = init_value(&options.crate_path, ty, None, fields, Direction::Backward);
106 let next_body = advance_struct(&options.crate_path, ty, fields, Direction::Forward);
107 let previous_body = advance_struct(&options.crate_path, ty, fields, Direction::Backward);
108 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
109 let where_clause = if generics.params.is_empty() {
110 where_clause.cloned()
111 } else {
112 let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
113 where_token: Default::default(),
114 predicates: Default::default(),
115 });
116 clause.predicates.extend(
117 trait_bounds(
118 &options.crate_path,
119 group_type_requirements(fields.iter().rev().zip(tuple_type_requirements())),
120 )
121 .map(WherePredicate::Type),
122 );
123 Some(clause)
124 };
125 let tokens = quote! {
126 impl #impl_generics #crate_path::Sequence for #ty #ty_generics #where_clause {
127 #[allow(clippy::identity_op)]
128 const CARDINALITY: usize = #cardinality;
129
130 fn next(&self) -> ::core::option::Option<Self> {
131 #next_body
132 }
133
134 fn previous(&self) -> ::core::option::Option<Self> {
135 #previous_body
136 }
137
138 fn first() -> ::core::option::Option<Self> {
139 #first
140 }
141
142 fn last() -> ::core::option::Option<Self> {
143 #last
144 }
145 }
146 };
147 Ok(tokens)
148}
149
150fn derive_for_enum(
151 options: &DeriveOptions,
152 ty: &Ident,
153 generics: &Generics,
154 variants: &Punctuated<Variant, Comma>,
155) -> Result<TokenStream, syn::Error> {
156 let cardinality = enum_cardinality(&options.crate_path, variants);
157 let next_body = advance_enum(&options.crate_path, ty, variants, Direction::Forward);
158 let previous_body = advance_enum(&options.crate_path, ty, variants, Direction::Backward);
159 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
160 let where_clause = if generics.params.is_empty() {
161 where_clause.cloned()
162 } else {
163 let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
164 where_token: Default::default(),
165 predicates: Default::default(),
166 });
167 clause.predicates.extend(
168 trait_bounds(
169 &options.crate_path,
170 group_type_requirements(variants.iter().flat_map(|variant| {
171 variant.fields.iter().rev().zip(tuple_type_requirements())
172 })),
173 )
174 .map(WherePredicate::Type),
175 );
176 Some(clause)
177 };
178 let next_variant_body = next_variant(&options.crate_path, ty, variants, Direction::Forward);
179 let previous_variant_body =
180 next_variant(&options.crate_path, ty, variants, Direction::Backward);
181 let (first, last) = if variants.is_empty() {
182 (
183 quote! { ::core::option::Option::None },
184 quote! { ::core::option::Option::None },
185 )
186 } else {
187 let last_index = variants.len() - 1;
188 (
189 quote! { next_variant(0) },
190 quote! { previous_variant(#last_index) },
191 )
192 };
193 let crate_path = &options.crate_path;
194 let tokens = quote! {
195 impl #impl_generics #crate_path::Sequence for #ty #ty_generics #where_clause {
196 #[allow(clippy::identity_op)]
197 const CARDINALITY: usize = #cardinality;
198
199 fn next(&self) -> ::core::option::Option<Self> {
200 #next_body
201 }
202
203 fn previous(&self) -> ::core::option::Option<Self> {
204 #previous_body
205 }
206
207 fn first() -> ::core::option::Option<Self> {
208 #first
209 }
210
211 fn last() -> ::core::option::Option<Self> {
212 #last
213 }
214 }
215
216 fn next_variant #impl_generics(
217 mut i: usize,
218 ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
219 #next_variant_body
220 }
221
222 fn previous_variant #impl_generics(
223 mut i: usize,
224 ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
225 #previous_variant_body
226 }
227 };
228 let tokens = quote! {
229 const _: () = { #tokens };
230 };
231 Ok(tokens)
232}
233
234fn enum_cardinality(crate_path: &Path, variants: &Punctuated<Variant, Comma>) -> TokenStream {
235 let terms = variants
236 .iter()
237 .map(|variant| tuple_cardinality(crate_path, &variant.fields));
238 quote! {
239 #((#terms) +)* 0
240 }
241}
242
243fn tuple_cardinality(crate_path: &Path, fields: &Fields) -> TokenStream {
244 let factors = fields.iter().map(|field| {
245 let ty = &field.ty;
246 quote! {
247 <#ty as #crate_path::Sequence>::CARDINALITY
248 }
249 });
250 quote! {
251 #(#factors *)* 1
252 }
253}
254
255fn field_id(field: &Field, index: usize) -> Member {
256 field
257 .ident
258 .clone()
259 .map_or_else(|| Member::from(index), Member::from)
260}
261
262fn init_value(
263 crate_path: &Path,
264 ty: &Ident,
265 variant: Option<&Ident>,
266 fields: &Fields,
267 direction: Direction,
268) -> TokenStream {
269 let id = variant.map_or_else(|| quote! { #ty }, |v| quote! { #ty::#v });
270 if fields.is_empty() {
271 quote! {
272 ::core::option::Option::Some(#id {})
273 }
274 } else {
275 let reset = direction.reset();
276 let initialization = repeat_n(quote! { #crate_path::Sequence::#reset() }, fields.len());
277 let assignments = field_assignments(fields);
278 let bindings = bindings().take(fields.len());
279 quote! {{
280 match (#(#initialization,)*) {
281 (#(::core::option::Option::Some(#bindings),)*) => {
282 ::core::option::Option::Some(#id { #assignments })
283 }
284 _ => ::core::option::Option::None,
285 }
286 }}
287 }
288}
289
290fn next_variant(
291 crate_path: &Path,
292 ty: &Ident,
293 variants: &Punctuated<Variant, Comma>,
294 direction: Direction,
295) -> TokenStream {
296 let advance = match direction {
297 Direction::Forward => {
298 let last_index = variants.len().saturating_sub(1);
299 quote! {
300 if i >= #last_index { break ::core::option::Option::None; } else { i+= 1; }
301 }
302 }
303 Direction::Backward => quote! {
304 if i == 0 { break ::core::option::Option::None; } else { i -= 1; }
305 },
306 };
307 let arms = variants.iter().enumerate().map(|(i, v)| {
308 let id = &v.ident;
309 let init = init_value(crate_path, ty, Some(id), &v.fields, direction);
310 quote! {
311 #i => #init
312 }
313 });
314 quote! {
315 loop {
316 let next = match i {
317 #(#arms,)*
318 _ => ::core::option::Option::None,
319 };
320 match next {
321 ::core::option::Option::Some(_) => break next,
322 ::core::option::Option::None => #advance,
323 }
324 }
325 }
326}
327
328fn advance_struct(
329 crate_path: &Path,
330 ty: &Ident,
331 fields: &Fields,
332 direction: Direction,
333) -> TokenStream {
334 let assignments = field_assignments(fields);
335 let bindings = bindings().take(fields.len()).collect::<Vec<_>>();
336 let tuple = advance_tuple(crate_path, &bindings, direction);
337 quote! {
338 let #ty { #assignments } = self;
339 let (#(#bindings,)*) = #tuple?;
340 ::core::option::Option::Some(#ty { #assignments })
341 }
342}
343
344fn advance_enum(
345 crate_path: &Path,
346 ty: &Ident,
347 variants: &Punctuated<Variant, Comma>,
348 direction: Direction,
349) -> TokenStream {
350 let arms: Vec<_> = match direction {
351 Direction::Forward => variants
352 .iter()
353 .enumerate()
354 .map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant))
355 .collect(),
356 Direction::Backward => variants
357 .iter()
358 .enumerate()
359 .rev()
360 .map(|(i, variant)| advance_enum_arm(crate_path, ty, direction, i, variant))
361 .collect(),
362 };
363 quote! {
364 match *self {
365 #(#arms,)*
366 }
367 }
368}
369
370fn advance_enum_arm(
371 crate_path: &Path,
372 ty: &Ident,
373 direction: Direction,
374 i: usize,
375 variant: &Variant,
376) -> TokenStream {
377 let next = match direction {
378 Direction::Forward => match i.checked_add(1) {
379 Some(next_i) => quote! { next_variant(#next_i) },
380 None => quote! { ::core::option::Option::None },
381 },
382 Direction::Backward => match i.checked_sub(1) {
383 Some(prev_i) => quote! { previous_variant(#prev_i) },
384 None => quote! { ::core::option::Option::None },
385 },
386 };
387 let id = &variant.ident;
388 if variant.fields.is_empty() {
389 quote! {
390 #ty::#id {} => #next
391 }
392 } else {
393 let destructuring = field_bindings(&variant.fields);
394 let assignments = field_assignments(&variant.fields);
395 let bindings = bindings().take(variant.fields.len()).collect::<Vec<_>>();
396 let tuple = advance_tuple(crate_path, &bindings, direction);
397 quote! {
398 #ty::#id { #destructuring } => {
399 let y = #tuple;
400 match y {
401 ::core::option::Option::Some((#(#bindings,)*)) => {
402 ::core::option::Option::Some(#ty::#id { #assignments })
403 }
404 ::core::option::Option::None => #next,
405 }
406 }
407 }
408 }
409}
410
411fn advance_tuple(crate_path: &Path, bindings: &[Ident], direction: Direction) -> TokenStream {
412 let advance = direction.advance();
413 let reset = direction.reset();
414 let rev_bindings = bindings.iter().rev().collect::<Vec<_>>();
415 let (rev_binding_head, rev_binding_tail) = match rev_bindings.split_first() {
416 Some((&head, tail)) => (Some(head), tail),
417 None => (None, &*rev_bindings),
418 };
419 let rev_binding_head = match rev_binding_head {
420 Some(head) => quote! {
421 let (#head, carry) = match #crate_path::Sequence::#advance(#head) {
422 ::core::option::Option::Some(#head) => (::core::option::Option::Some(#head), false),
423 ::core::option::Option::None => (#crate_path::Sequence::#reset(), true),
424 };
425 },
426 None => quote! {
427 let carry = true;
428 },
429 };
430 let body = quote! {
431 #rev_binding_head
432 #(
433 let (#rev_binding_tail, carry) = if carry {
434 match #crate_path::Sequence::#advance(#rev_binding_tail) {
435 ::core::option::Option::Some(#rev_binding_tail) => {
436 (::core::option::Option::Some(#rev_binding_tail), false)
437 }
438 ::core::option::Option::None => (#crate_path::Sequence::#reset(), true),
439 }
440 } else {
441 (
442 ::core::option::Option::Some(::core::clone::Clone::clone(#rev_binding_tail)),
443 false,
444 )
445 };
446 )*
447 if carry {
448 ::core::option::Option::None
449 } else {
450 match (#(#bindings,)*) {
451 (#(::core::option::Option::Some(#bindings),)*) => {
452 ::core::option::Option::Some((#(#bindings,)*))
453 }
454 _ => ::core::option::Option::None,
455 }
456 }
457 };
458 quote! {
459 { #body }
460 }
461}
462
463fn field_assignments<'a, I>(fields: I) -> TokenStream
464where
465 I: IntoIterator<Item = &'a Field>,
466{
467 fields
468 .into_iter()
469 .enumerate()
470 .zip(bindings())
471 .map(|((i, field), binding)| {
472 let field_id = field_id(field, i);
473 quote! { #field_id: #binding, }
474 })
475 .collect()
476}
477
478fn field_bindings<'a, I>(fields: I) -> TokenStream
479where
480 I: IntoIterator<Item = &'a Field>,
481{
482 fields
483 .into_iter()
484 .enumerate()
485 .zip(bindings())
486 .map(|((i, field), binding)| {
487 let field_id = field_id(field, i);
488 quote! { #field_id: ref #binding, }
489 })
490 .collect()
491}
492
493fn bindings() -> impl Iterator<Item = Ident> {
494 (0..).map(|i| Ident::new(&format!("x{i}"), Span::call_site()))
495}
496
497fn trait_bounds<I>(crate_path: &Path, types: I) -> impl Iterator<Item = PredicateType>
498where
499 I: IntoIterator<Item = (Type, TypeRequirements)>,
500{
501 let crate_path = crate_path.clone();
502 types
503 .into_iter()
504 .map(move |(bounded_ty, requirements)| PredicateType {
505 lifetimes: None,
506 bounded_ty,
507 colon_token: Default::default(),
508 bounds: requirements
509 .into_iter()
510 .map(|req| match req {
511 TypeRequirement::Clone => clone_trait_path(),
512 TypeRequirement::Sequence => trait_path(&crate_path),
513 })
514 .map(trait_bound)
515 .collect(),
516 })
517}
518
519fn trait_bound(path: Path) -> TypeParamBound {
520 TypeParamBound::Trait(TraitBound {
521 paren_token: None,
522 modifier: TraitBoundModifier::None,
523 lifetimes: None,
524 path,
525 })
526}
527
528fn trait_path(crate_path: &Path) -> Path {
529 let mut path = crate_path.clone();
530 path.segments
531 .push(Ident::new("Sequence", Span::call_site()).into());
532 path
533}
534
535fn clone_trait_path() -> Path {
536 Path {
537 leading_colon: Some(Default::default()),
538 segments: [
539 PathSegment::from(Ident::new("core", Span::call_site())),
540 Ident::new("clone", Span::call_site()).into(),
541 Ident::new("Clone", Span::call_site()).into(),
542 ]
543 .into_iter()
544 .collect(),
545 }
546}
547
548fn tuple_type_requirements() -> impl Iterator<Item = TypeRequirements> {
549 once([TypeRequirement::Sequence].into()).chain(repeat(
550 [TypeRequirement::Sequence, TypeRequirement::Clone].into(),
551 ))
552}
553
554fn group_type_requirements<'a, I>(bounds: I) -> Vec<(Type, TypeRequirements)>
555where
556 I: IntoIterator<Item = (&'a Field, TypeRequirements)>,
557{
558 bounds
559 .into_iter()
560 .fold(
561 (HashMap::<_, usize>::new(), Vec::new()),
562 |(mut indexes, mut acc), (field, requirements)| {
563 let i = *indexes.entry(field.ty.clone()).or_insert_with(|| {
564 acc.push((field.ty.clone(), TypeRequirements::new()));
565 acc.len() - 1
566 });
567 acc[i].1.extend(requirements);
568 (indexes, acc)
569 },
570 )
571 .1
572}
573
/// A trait that a field type may be required to implement.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TypeRequirement {
    Sequence,
    Clone,
}

/// A set of [`TypeRequirement`]s, stored as a bit mask.
#[derive(Clone, Debug, Default, PartialEq)]
struct TypeRequirements(u8);

impl TypeRequirements {
    /// Bit standing for [`TypeRequirement::Sequence`].
    const SEQUENCE: u8 = 0x1;
    /// Bit standing for [`TypeRequirement::Clone`].
    const CLONE: u8 = 0x2;

    /// Creates an empty set.
    fn new() -> Self {
        Self::default()
    }

    /// Adds `req` to the set; adding a requirement twice has no further effect.
    fn insert(&mut self, req: TypeRequirement) {
        let mask = Self::enum_to_mask(req);
        self.0 |= mask;
    }

    /// Iterates over the requirements in the set, `Sequence` before `Clone`.
    fn into_iter(self) -> impl Iterator<Item = TypeRequirement> {
        let mut remaining = self.0;
        iter::from_fn(move || {
            // Pick the first requirement whose bit is still set, then clear it.
            let (bit, requirement) = [
                (Self::SEQUENCE, TypeRequirement::Sequence),
                (Self::CLONE, TypeRequirement::Clone),
            ]
            .into_iter()
            .find(|(bit, _)| remaining & bit != 0)?;
            remaining &= !bit;
            Some(requirement)
        })
    }

    /// Adds every requirement of `other` to the set.
    fn extend(&mut self, other: Self) {
        let Self(mask) = other;
        self.0 |= mask;
    }

    /// Returns the bit standing for `req`.
    fn enum_to_mask(req: TypeRequirement) -> u8 {
        match req {
            TypeRequirement::Clone => Self::CLONE,
            TypeRequirement::Sequence => Self::SEQUENCE,
        }
    }
}

impl<const N: usize> From<[TypeRequirement; N]> for TypeRequirements {
    /// Builds the set containing every requirement listed in `reqs`.
    fn from(reqs: [TypeRequirement; N]) -> Self {
        let mut set = TypeRequirements::new();
        for req in reqs {
            set.insert(req);
        }
        set
    }
}
631
632#[derive(Clone, Copy, Debug, Eq, PartialEq)]
633enum Direction {
634 Forward,
635 Backward,
636}
637
638impl Direction {
639 fn advance(self) -> Ident {
640 let s = match self {
641 Direction::Forward => "next",
642 Direction::Backward => "previous",
643 };
644 Ident::new(s, Span::call_site())
645 }
646
647 fn reset(self) -> Ident {
648 let s = match self {
649 Direction::Forward => "first",
650 Direction::Backward => "last",
651 };
652 Ident::new(s, Span::call_site())
653 }
654}
655
656#[derive(Debug)]
657enum Error {
658 UnsupportedUnion,
659}
660
661impl Error {
662 fn with_tokens<T: ToTokens>(self, tokens: T) -> syn::Error {
663 syn::Error::new_spanned(tokens, self)
664 }
665}
666
667impl Display for Error {
668 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
669 match self {
670 Error::UnsupportedUnion => f.write_str("Sequence cannot be derived for union types"),
671 }
672 }
673}
674
#[cfg(test)]
mod tests {
    use crate::DeriveOptions;
    use quote::quote;

    /// A `crate = <path>` option is stored as the crate path.
    #[test]
    fn crate_path_can_be_parsed() {
        let item = quote! {
            #[derive(Sequence)]
            #[enum_iterator(crate = foo::bar)]
            struct Foo;
        };
        let input: syn::DeriveInput = syn::parse2(item).unwrap();
        let expected_path: syn::Path = syn::parse2(quote! { foo::bar }).unwrap();
        let options = DeriveOptions::parse(&input.attrs).unwrap();
        assert_eq!(options.crate_path, expected_path);
    }
}