1#![recursion_limit = "128"]
16#![deny(warnings)]
17
18extern crate proc_macro;
19
20use proc_macro2::{Span, TokenStream};
21use quote::{quote, ToTokens};
22use std::{
23 collections::HashMap,
24 fmt::{self, Display},
25 iter::{self, once, repeat},
26};
27use syn::{
28 punctuated::Punctuated, token::Comma, DeriveInput, Field, Fields, Generics, Ident, Member,
29 Path, PathSegment, PredicateType, TraitBound, TraitBoundModifier, Type, TypeParamBound,
30 Variant, WhereClause, WherePredicate,
31};
32
33#[proc_macro_derive(Sequence)]
35pub fn derive_sequence(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36 derive(input)
37 .unwrap_or_else(|e| e.to_compile_error())
38 .into()
39}
40
41fn derive(input: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
42 derive_for_ast(syn::parse(input)?)
43}
44
45fn derive_for_ast(ast: DeriveInput) -> Result<TokenStream, syn::Error> {
46 let ty = &ast.ident;
47 let generics = &ast.generics;
48 match &ast.data {
49 syn::Data::Struct(s) => derive_for_struct(ty, generics, &s.fields),
50 syn::Data::Enum(e) => derive_for_enum(ty, generics, &e.variants),
51 syn::Data::Union(_) => Err(Error::UnsupportedUnion.with_tokens(&ast)),
52 }
53}
54
55fn derive_for_struct(
56 ty: &Ident,
57 generics: &Generics,
58 fields: &Fields,
59) -> Result<TokenStream, syn::Error> {
60 let cardinality = tuple_cardinality(fields);
61 let first = init_value(ty, None, fields, Direction::Forward);
62 let last = init_value(ty, None, fields, Direction::Backward);
63 let next_body = advance_struct(ty, fields, Direction::Forward);
64 let previous_body = advance_struct(ty, fields, Direction::Backward);
65 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
66 let where_clause = if generics.params.is_empty() {
67 where_clause.cloned()
68 } else {
69 let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
70 where_token: Default::default(),
71 predicates: Default::default(),
72 });
73 clause.predicates.extend(
74 trait_bounds(group_type_requirements(
75 fields.iter().rev().zip(tuple_type_requirements()),
76 ))
77 .map(WherePredicate::Type),
78 );
79 Some(clause)
80 };
81 let tokens = quote! {
82 impl #impl_generics ::enum_iterator::Sequence for #ty #ty_generics #where_clause {
83 #[allow(clippy::identity_op)]
84 const CARDINALITY: usize = #cardinality;
85
86 fn next(&self) -> ::core::option::Option<Self> {
87 #next_body
88 }
89
90 fn previous(&self) -> ::core::option::Option<Self> {
91 #previous_body
92 }
93
94 fn first() -> ::core::option::Option<Self> {
95 #first
96 }
97
98 fn last() -> ::core::option::Option<Self> {
99 #last
100 }
101 }
102 };
103 Ok(tokens)
104}
105
106fn derive_for_enum(
107 ty: &Ident,
108 generics: &Generics,
109 variants: &Punctuated<Variant, Comma>,
110) -> Result<TokenStream, syn::Error> {
111 let cardinality = enum_cardinality(variants);
112 let next_body = advance_enum(ty, variants, Direction::Forward);
113 let previous_body = advance_enum(ty, variants, Direction::Backward);
114 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
115 let where_clause = if generics.params.is_empty() {
116 where_clause.cloned()
117 } else {
118 let mut clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
119 where_token: Default::default(),
120 predicates: Default::default(),
121 });
122 clause.predicates.extend(
123 trait_bounds(group_type_requirements(variants.iter().flat_map(
124 |variant| variant.fields.iter().rev().zip(tuple_type_requirements()),
125 )))
126 .map(WherePredicate::Type),
127 );
128 Some(clause)
129 };
130 let next_variant_body = next_variant(ty, variants, Direction::Forward);
131 let previous_variant_body = next_variant(ty, variants, Direction::Backward);
132 let (first, last) = if variants.is_empty() {
133 (
134 quote! { ::core::option::Option::None },
135 quote! { ::core::option::Option::None },
136 )
137 } else {
138 let last_index = variants.len() - 1;
139 (
140 quote! { next_variant(0) },
141 quote! { previous_variant(#last_index) },
142 )
143 };
144 let tokens = quote! {
145 impl #impl_generics ::enum_iterator::Sequence for #ty #ty_generics #where_clause {
146 #[allow(clippy::identity_op)]
147 const CARDINALITY: usize = #cardinality;
148
149 fn next(&self) -> ::core::option::Option<Self> {
150 #next_body
151 }
152
153 fn previous(&self) -> ::core::option::Option<Self> {
154 #previous_body
155 }
156
157 fn first() -> ::core::option::Option<Self> {
158 #first
159 }
160
161 fn last() -> ::core::option::Option<Self> {
162 #last
163 }
164 }
165
166 fn next_variant #impl_generics(
167 mut i: usize,
168 ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
169 #next_variant_body
170 }
171
172 fn previous_variant #impl_generics(
173 mut i: usize,
174 ) -> ::core::option::Option<#ty #ty_generics> #where_clause {
175 #previous_variant_body
176 }
177 };
178 let tokens = quote! {
179 const _: () = { #tokens };
180 };
181 Ok(tokens)
182}
183
184fn enum_cardinality(variants: &Punctuated<Variant, Comma>) -> TokenStream {
185 let terms = variants
186 .iter()
187 .map(|variant| tuple_cardinality(&variant.fields));
188 quote! {
189 #((#terms) +)* 0
190 }
191}
192
193fn tuple_cardinality(fields: &Fields) -> TokenStream {
194 let factors = fields.iter().map(|field| {
195 let ty = &field.ty;
196 quote! {
197 <#ty as ::enum_iterator::Sequence>::CARDINALITY
198 }
199 });
200 quote! {
201 #(#factors *)* 1
202 }
203}
204
205fn field_id(field: &Field, index: usize) -> Member {
206 field
207 .ident
208 .clone()
209 .map_or_else(|| Member::from(index), Member::from)
210}
211
212fn init_value(
213 ty: &Ident,
214 variant: Option<&Ident>,
215 fields: &Fields,
216 direction: Direction,
217) -> TokenStream {
218 let id = variant.map_or_else(|| quote! { #ty }, |v| quote! { #ty::#v });
219 if fields.is_empty() {
220 quote! {
221 ::core::option::Option::Some(#id {})
222 }
223 } else {
224 let reset = direction.reset();
225 let initialization =
226 repeat(quote! { ::enum_iterator::Sequence::#reset() }).take(fields.len());
227 let assignments = field_assignments(fields);
228 let bindings = bindings().take(fields.len());
229 quote! {{
230 match (#(#initialization,)*) {
231 (#(::core::option::Option::Some(#bindings),)*) => {
232 ::core::option::Option::Some(#id { #assignments })
233 }
234 _ => ::core::option::Option::None,
235 }
236 }}
237 }
238}
239
240fn next_variant(
241 ty: &Ident,
242 variants: &Punctuated<Variant, Comma>,
243 direction: Direction,
244) -> TokenStream {
245 let advance = match direction {
246 Direction::Forward => {
247 let last_index = variants.len().saturating_sub(1);
248 quote! {
249 if i >= #last_index { break ::core::option::Option::None; } else { i+= 1; }
250 }
251 }
252 Direction::Backward => quote! {
253 if i == 0 { break ::core::option::Option::None; } else { i -= 1; }
254 },
255 };
256 let arms = variants.iter().enumerate().map(|(i, v)| {
257 let id = &v.ident;
258 let init = init_value(ty, Some(id), &v.fields, direction);
259 quote! {
260 #i => #init
261 }
262 });
263 quote! {
264 loop {
265 let next = match i {
266 #(#arms,)*
267 _ => ::core::option::Option::None,
268 };
269 match next {
270 ::core::option::Option::Some(_) => break next,
271 ::core::option::Option::None => #advance,
272 }
273 }
274 }
275}
276
277fn advance_struct(ty: &Ident, fields: &Fields, direction: Direction) -> TokenStream {
278 let assignments = field_assignments(fields);
279 let bindings = bindings().take(fields.len()).collect::<Vec<_>>();
280 let tuple = advance_tuple(&bindings, direction);
281 quote! {
282 let #ty { #assignments } = self;
283 let (#(#bindings,)*) = #tuple?;
284 ::core::option::Option::Some(#ty { #assignments })
285 }
286}
287
288fn advance_enum(
289 ty: &Ident,
290 variants: &Punctuated<Variant, Comma>,
291 direction: Direction,
292) -> TokenStream {
293 let arms: Vec<_> = match direction {
294 Direction::Forward => variants
295 .iter()
296 .enumerate()
297 .map(|(i, variant)| advance_enum_arm(ty, direction, i, variant))
298 .collect(),
299 Direction::Backward => variants
300 .iter()
301 .enumerate()
302 .rev()
303 .map(|(i, variant)| advance_enum_arm(ty, direction, i, variant))
304 .collect(),
305 };
306 quote! {
307 match *self {
308 #(#arms,)*
309 }
310 }
311}
312
313fn advance_enum_arm(ty: &Ident, direction: Direction, i: usize, variant: &Variant) -> TokenStream {
314 let next = match direction {
315 Direction::Forward => match i.checked_add(1) {
316 Some(next_i) => quote! { next_variant(#next_i) },
317 None => quote! { ::core::option::Option::None },
318 },
319 Direction::Backward => match i.checked_sub(1) {
320 Some(prev_i) => quote! { previous_variant(#prev_i) },
321 None => quote! { ::core::option::Option::None },
322 },
323 };
324 let id = &variant.ident;
325 if variant.fields.is_empty() {
326 quote! {
327 #ty::#id {} => #next
328 }
329 } else {
330 let destructuring = field_bindings(&variant.fields);
331 let assignments = field_assignments(&variant.fields);
332 let bindings = bindings().take(variant.fields.len()).collect::<Vec<_>>();
333 let tuple = advance_tuple(&bindings, direction);
334 quote! {
335 #ty::#id { #destructuring } => {
336 let y = #tuple;
337 match y {
338 ::core::option::Option::Some((#(#bindings,)*)) => {
339 ::core::option::Option::Some(#ty::#id { #assignments })
340 }
341 ::core::option::Option::None => #next,
342 }
343 }
344 }
345 }
346}
347
348fn advance_tuple(bindings: &[Ident], direction: Direction) -> TokenStream {
349 let advance = direction.advance();
350 let reset = direction.reset();
351 let rev_bindings = bindings.iter().rev().collect::<Vec<_>>();
352 let (rev_binding_head, rev_binding_tail) = match rev_bindings.split_first() {
353 Some((&head, tail)) => (Some(head), tail),
354 None => (None, &*rev_bindings),
355 };
356 let rev_binding_head = match rev_binding_head {
357 Some(head) => quote! {
358 let (#head, carry) = match ::enum_iterator::Sequence::#advance(#head) {
359 ::core::option::Option::Some(#head) => (::core::option::Option::Some(#head), false),
360 ::core::option::Option::None => (::enum_iterator::Sequence::#reset(), true),
361 };
362 },
363 None => quote! {
364 let carry = true;
365 },
366 };
367 let body = quote! {
368 #rev_binding_head
369 #(
370 let (#rev_binding_tail, carry) = if carry {
371 match ::enum_iterator::Sequence::#advance(#rev_binding_tail) {
372 ::core::option::Option::Some(#rev_binding_tail) => {
373 (::core::option::Option::Some(#rev_binding_tail), false)
374 }
375 ::core::option::Option::None => (::enum_iterator::Sequence::#reset(), true),
376 }
377 } else {
378 (
379 ::core::option::Option::Some(::core::clone::Clone::clone(#rev_binding_tail)),
380 false,
381 )
382 };
383 )*
384 if carry {
385 ::core::option::Option::None
386 } else {
387 match (#(#bindings,)*) {
388 (#(::core::option::Option::Some(#bindings),)*) => {
389 ::core::option::Option::Some((#(#bindings,)*))
390 }
391 _ => ::core::option::Option::None,
392 }
393 }
394 };
395 quote! {
396 { #body }
397 }
398}
399
400fn field_assignments<'a, I>(fields: I) -> TokenStream
401where
402 I: IntoIterator<Item = &'a Field>,
403{
404 fields
405 .into_iter()
406 .enumerate()
407 .zip(bindings())
408 .map(|((i, field), binding)| {
409 let field_id = field_id(field, i);
410 quote! { #field_id: #binding, }
411 })
412 .collect()
413}
414
415fn field_bindings<'a, I>(fields: I) -> TokenStream
416where
417 I: IntoIterator<Item = &'a Field>,
418{
419 fields
420 .into_iter()
421 .enumerate()
422 .zip(bindings())
423 .map(|((i, field), binding)| {
424 let field_id = field_id(field, i);
425 quote! { #field_id: ref #binding, }
426 })
427 .collect()
428}
429
430fn bindings() -> impl Iterator<Item = Ident> {
431 (0..).map(|i| Ident::new(&format!("x{i}"), Span::call_site()))
432}
433
434fn trait_bounds<I>(types: I) -> impl Iterator<Item = PredicateType>
435where
436 I: IntoIterator<Item = (Type, TypeRequirements)>,
437{
438 types
439 .into_iter()
440 .map(|(bounded_ty, requirements)| PredicateType {
441 lifetimes: None,
442 bounded_ty,
443 colon_token: Default::default(),
444 bounds: requirements
445 .into_iter()
446 .map(|req| match req {
447 TypeRequirement::Clone => clone_trait_path(),
448 TypeRequirement::Sequence => trait_path(),
449 })
450 .map(trait_bound)
451 .collect(),
452 })
453}
454
455fn trait_bound(path: Path) -> TypeParamBound {
456 TypeParamBound::Trait(TraitBound {
457 paren_token: None,
458 modifier: TraitBoundModifier::None,
459 lifetimes: None,
460 path,
461 })
462}
463
464fn trait_path() -> Path {
465 Path {
466 leading_colon: Some(Default::default()),
467 segments: [
468 PathSegment::from(Ident::new("enum_iterator", Span::call_site())),
469 Ident::new("Sequence", Span::call_site()).into(),
470 ]
471 .into_iter()
472 .collect(),
473 }
474}
475
476fn clone_trait_path() -> Path {
477 Path {
478 leading_colon: Some(Default::default()),
479 segments: [
480 PathSegment::from(Ident::new("core", Span::call_site())),
481 Ident::new("clone", Span::call_site()).into(),
482 Ident::new("Clone", Span::call_site()).into(),
483 ]
484 .into_iter()
485 .collect(),
486 }
487}
488
489fn tuple_type_requirements() -> impl Iterator<Item = TypeRequirements> {
490 once([TypeRequirement::Sequence].into()).chain(repeat(
491 [TypeRequirement::Sequence, TypeRequirement::Clone].into(),
492 ))
493}
494
495fn group_type_requirements<'a, I>(bounds: I) -> Vec<(Type, TypeRequirements)>
496where
497 I: IntoIterator<Item = (&'a Field, TypeRequirements)>,
498{
499 bounds
500 .into_iter()
501 .fold(
502 (HashMap::<_, usize>::new(), Vec::new()),
503 |(mut indexes, mut acc), (field, requirements)| {
504 let i = *indexes.entry(field.ty.clone()).or_insert_with(|| {
505 acc.push((field.ty.clone(), TypeRequirements::new()));
506 acc.len() - 1
507 });
508 acc[i].1.extend(requirements);
509 (indexes, acc)
510 },
511 )
512 .1
513}
514
/// A single trait that the generated code requires of a field type.
#[derive(Debug, PartialEq, Clone, Copy)]
enum TypeRequirement {
    /// The field type must implement `::enum_iterator::Sequence`.
    Sequence,
    /// The field type must implement `::core::clone::Clone`.
    Clone,
}
520
521#[derive(Clone, Debug, Default, PartialEq)]
522struct TypeRequirements(u8);
523
524impl TypeRequirements {
525 const SEQUENCE: u8 = 0x1;
526 const CLONE: u8 = 0x2;
527
528 fn new() -> Self {
529 Self::default()
530 }
531
532 fn insert(&mut self, req: TypeRequirement) {
533 self.0 |= Self::enum_to_mask(req);
534 }
535
536 fn into_iter(self) -> impl Iterator<Item = TypeRequirement> {
537 let mut n = self.0;
538 iter::from_fn(move || {
539 if n & Self::SEQUENCE != 0 {
540 n &= !Self::SEQUENCE;
541 Some(TypeRequirement::Sequence)
542 } else if n & Self::CLONE != 0 {
543 n &= !Self::CLONE;
544 Some(TypeRequirement::Clone)
545 } else {
546 None
547 }
548 })
549 }
550
551 fn extend(&mut self, other: Self) {
552 self.0 |= other.0;
553 }
554
555 fn enum_to_mask(req: TypeRequirement) -> u8 {
556 match req {
557 TypeRequirement::Sequence => Self::SEQUENCE,
558 TypeRequirement::Clone => Self::CLONE,
559 }
560 }
561}
562
563impl<const N: usize> From<[TypeRequirement; N]> for TypeRequirements {
564 fn from(reqs: [TypeRequirement; N]) -> Self {
565 reqs.into_iter()
566 .fold(TypeRequirements::new(), |mut acc, req| {
567 acc.insert(req);
568 acc
569 })
570 }
571}
572
/// Which way the generated code walks a sequence.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Direction {
    /// Uses `next` and starts from `first`.
    Forward,
    /// Uses `previous` and starts from `last`.
    Backward,
}
578
579impl Direction {
580 fn advance(self) -> Ident {
581 let s = match self {
582 Direction::Forward => "next",
583 Direction::Backward => "previous",
584 };
585 Ident::new(s, Span::call_site())
586 }
587
588 fn reset(self) -> Ident {
589 let s = match self {
590 Direction::Forward => "first",
591 Direction::Backward => "last",
592 };
593 Ident::new(s, Span::call_site())
594 }
595}
596
/// Reasons `#[derive(Sequence)]` can refuse an input item.
#[derive(Debug)]
enum Error {
    /// The input is a `union`. Only structs and enums are supported.
    UnsupportedUnion,
}
601
602impl Error {
603 fn with_tokens<T: ToTokens>(self, tokens: T) -> syn::Error {
604 syn::Error::new_spanned(tokens, self)
605 }
606}
607
608impl Display for Error {
609 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
610 match self {
611 Error::UnsupportedUnion => f.write_str("Sequence cannot be derived for union types"),
612 }
613 }
614}