seq_macro/
lib.rs

1//! [![github]](https://github.com/dtolnay/seq-macro) [![crates-io]](https://crates.io/crates/seq-macro) [![docs-rs]](https://docs.rs/seq-macro)
2//!
3//! [github]: https://img.shields.io/badge/github-8da0cb?style=for-the-badge&labelColor=555555&logo=github
4//! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
5//! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
6//!
7//! <br>
8//!
9//! # Imagine for-loops in a macro
10//!
11//! This crate provides a `seq!` macro to repeat a fragment of source code and
12//! substitute into each repetition a sequential numeric counter.
13//!
14//! ```
15//! use seq_macro::seq;
16//!
17//! fn main() {
18//!     let tuple = (1000, 100, 10);
19//!     let mut sum = 0;
20//!
21//!     // Expands to:
22//!     //
23//!     //     sum += tuple.0;
24//!     //     sum += tuple.1;
25//!     //     sum += tuple.2;
26//!     //
27//!     // This cannot be written using an ordinary for-loop because elements of
28//!     // a tuple can only be accessed by their integer literal index, not by a
29//!     // variable.
30//!     seq!(N in 0..=2 {
31//!         sum += tuple.N;
32//!     });
33//!
34//!     assert_eq!(sum, 1110);
35//! }
36//! ```
37//!
38//! - If the input tokens contain a section surrounded by `#(` ... `)*` then
39//!   only that part is repeated.
40//!
41//! - The numeric counter can be pasted onto the end of some prefix to form
42//!   sequential identifiers.
43//!
44//! ```
45//! use seq_macro::seq;
46//!
47//! seq!(N in 64..=127 {
48//!     #[derive(Debug)]
49//!     enum Demo {
50//!         // Expands to Variant64, Variant65, ...
51//!         ##(
52//!             Variant~N,
53//!         )*
54//!     }
55//! });
56//!
57//! fn main() {
58//!     assert_eq!("Variant99", format!("{:?}", Demo::Variant99));
59//! }
60//! ```
61//!
62//! - Byte and character ranges are supported: `b'a'..=b'z'`, `'a'..='z'`.
63//!
64//! - If the range bounds are written in binary, octal, hex, or with zero
65//!   padding, those features are preserved in any generated tokens.
66//!
67//! ```
68//! use seq_macro::seq;
69//!
70//! seq!(P in 0x000..=0x00F {
71//!     // expands to structs Pin000, ..., Pin009, Pin00A, ..., Pin00F
72//!     struct Pin~P;
73//! });
74//! ```
75
// Crate-wide Clippy allowances; each entry notes where it applies, as far as
// can be seen from this file.
#![allow(
    // NOTE(review): no `as` cast in lib.rs looks lossless; presumably needed by
    // casts elsewhere in the crate (e.g. the `parse` module) — confirm.
    clippy::cast_lossless,
    // `u64` values are narrowed with `as u32` before calling `char::from_u32`.
    clippy::cast_possible_truncation,
    // `Kind` and `Radix` derive `PartialEq` without `Eq`.
    clippy::derive_partial_eq_without_eq,
    // The crate-level doc examples spell out `fn main()` explicitly.
    clippy::needless_doctest_main,
    // Presumably for `match` expressions with one pattern arm plus a fallback
    // arm that are kept as `match` for readability — confirm where it fires.
    clippy::single_match_else,
    // `use crate::parse::*;` pulls in every parsing helper at once.
    clippy::wildcard_imports
)]
84
85mod parse;
86
87use crate::parse::*;
88use proc_macro::{Delimiter, Group, Ident, Literal, Span, TokenStream, TokenTree};
89use std::char;
90use std::iter::{self, FromIterator};
91
92#[proc_macro]
93pub fn seq(input: TokenStream) -> TokenStream {
94    match seq_impl(input) {
95        Ok(expanded) => expanded,
96        Err(error) => error.into_compile_error(),
97    }
98}
99
100struct Range {
101    begin: u64,
102    end: u64,
103    inclusive: bool,
104    kind: Kind,
105    suffix: String,
106    width: usize,
107    radix: Radix,
108}
109
110struct Value {
111    int: u64,
112    kind: Kind,
113    suffix: String,
114    width: usize,
115    radix: Radix,
116    span: Span,
117}
118
119struct Splice<'a> {
120    int: u64,
121    kind: Kind,
122    suffix: &'a str,
123    width: usize,
124    radix: Radix,
125}
126
127#[derive(Copy, Clone, PartialEq)]
128enum Kind {
129    Int,
130    Byte,
131    Char,
132}
133
134#[derive(Copy, Clone, PartialEq)]
135enum Radix {
136    Binary,
137    Octal,
138    Decimal,
139    LowerHex,
140    UpperHex,
141}
142
143impl<'a> IntoIterator for &'a Range {
144    type Item = Splice<'a>;
145    type IntoIter = Box<dyn Iterator<Item = Splice<'a>> + 'a>;
146
147    fn into_iter(self) -> Self::IntoIter {
148        let splice = move |int| Splice {
149            int,
150            kind: self.kind,
151            suffix: &self.suffix,
152            width: self.width,
153            radix: self.radix,
154        };
155        match self.kind {
156            Kind::Int | Kind::Byte => {
157                if self.inclusive {
158                    Box::new((self.begin..=self.end).map(splice))
159                } else {
160                    Box::new((self.begin..self.end).map(splice))
161                }
162            }
163            Kind::Char => {
164                let begin = char::from_u32(self.begin as u32).unwrap();
165                let end = char::from_u32(self.end as u32).unwrap();
166                let int = |ch| u64::from(u32::from(ch));
167                if self.inclusive {
168                    Box::new((begin..=end).map(int).map(splice))
169                } else {
170                    Box::new((begin..end).map(int).map(splice))
171                }
172            }
173        }
174    }
175}
176
177fn seq_impl(input: TokenStream) -> Result<TokenStream, SyntaxError> {
178    let mut iter = input.into_iter();
179    let var = require_ident(&mut iter)?;
180    require_keyword(&mut iter, "in")?;
181    let begin = require_value(&mut iter)?;
182    require_punct(&mut iter, '.')?;
183    require_punct(&mut iter, '.')?;
184    let inclusive = require_if_punct(&mut iter, '=')?;
185    let end = require_value(&mut iter)?;
186    let body = require_braces(&mut iter)?;
187    require_end(&mut iter)?;
188
189    let range = validate_range(begin, end, inclusive)?;
190
191    let mut found_repetition = false;
192    let expanded = expand_repetitions(&var, &range, body.clone(), &mut found_repetition);
193    if found_repetition {
194        Ok(expanded)
195    } else {
196        // If no `#(...)*`, repeat the entire body.
197        Ok(repeat(&var, &range, &body))
198    }
199}
200
201fn repeat(var: &Ident, range: &Range, body: &TokenStream) -> TokenStream {
202    let mut repeated = TokenStream::new();
203    for value in range {
204        repeated.extend(substitute_value(var, &value, body.clone()));
205    }
206    repeated
207}
208
209fn substitute_value(var: &Ident, splice: &Splice, body: TokenStream) -> TokenStream {
210    let mut tokens = Vec::from_iter(body);
211
212    let mut i = 0;
213    while i < tokens.len() {
214        // Substitute our variable by itself, e.g. `N`.
215        let replace = match &tokens[i] {
216            TokenTree::Ident(ident) => ident.to_string() == var.to_string(),
217            _ => false,
218        };
219        if replace {
220            let original_span = tokens[i].span();
221            let mut literal = splice.literal();
222            literal.set_span(original_span);
223            tokens[i] = TokenTree::Literal(literal);
224            i += 1;
225            continue;
226        }
227
228        // Substitute our variable concatenated onto some prefix, `Prefix~N`.
229        if i + 3 <= tokens.len() {
230            let prefix = match &tokens[i..i + 3] {
231                [first, TokenTree::Punct(tilde), TokenTree::Ident(ident)]
232                    if tilde.as_char() == '~' && ident.to_string() == var.to_string() =>
233                {
234                    match first {
235                        TokenTree::Ident(ident) => Some(ident.clone()),
236                        TokenTree::Group(group) => {
237                            let mut iter = group.stream().into_iter().fuse();
238                            match (iter.next(), iter.next()) {
239                                (Some(TokenTree::Ident(ident)), None) => Some(ident),
240                                _ => None,
241                            }
242                        }
243                        _ => None,
244                    }
245                }
246                _ => None,
247            };
248            if let Some(prefix) = prefix {
249                let number = match splice.kind {
250                    Kind::Int => match splice.radix {
251                        Radix::Binary => format!("{0:01$b}", splice.int, splice.width),
252                        Radix::Octal => format!("{0:01$o}", splice.int, splice.width),
253                        Radix::Decimal => format!("{0:01$}", splice.int, splice.width),
254                        Radix::LowerHex => format!("{0:01$x}", splice.int, splice.width),
255                        Radix::UpperHex => format!("{0:01$X}", splice.int, splice.width),
256                    },
257                    Kind::Byte | Kind::Char => {
258                        char::from_u32(splice.int as u32).unwrap().to_string()
259                    }
260                };
261                let concat = format!("{}{}", prefix, number);
262                let ident = Ident::new(&concat, prefix.span());
263                tokens.splice(i..i + 3, iter::once(TokenTree::Ident(ident)));
264                i += 1;
265                continue;
266            }
267        }
268
269        // Recursively substitute content nested in a group.
270        if let TokenTree::Group(group) = &mut tokens[i] {
271            let original_span = group.span();
272            let content = substitute_value(var, splice, group.stream());
273            *group = Group::new(group.delimiter(), content);
274            group.set_span(original_span);
275        }
276
277        i += 1;
278    }
279
280    TokenStream::from_iter(tokens)
281}
282
283fn enter_repetition(tokens: &[TokenTree]) -> Option<TokenStream> {
284    assert!(tokens.len() == 3);
285    match &tokens[0] {
286        TokenTree::Punct(punct) if punct.as_char() == '#' => {}
287        _ => return None,
288    }
289    match &tokens[2] {
290        TokenTree::Punct(punct) if punct.as_char() == '*' => {}
291        _ => return None,
292    }
293    match &tokens[1] {
294        TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
295            Some(group.stream())
296        }
297        _ => None,
298    }
299}
300
301fn expand_repetitions(
302    var: &Ident,
303    range: &Range,
304    body: TokenStream,
305    found_repetition: &mut bool,
306) -> TokenStream {
307    let mut tokens = Vec::from_iter(body);
308
309    // Look for `#(...)*`.
310    let mut i = 0;
311    while i < tokens.len() {
312        if let TokenTree::Group(group) = &mut tokens[i] {
313            let content = expand_repetitions(var, range, group.stream(), found_repetition);
314            let original_span = group.span();
315            *group = Group::new(group.delimiter(), content);
316            group.set_span(original_span);
317            i += 1;
318            continue;
319        }
320        if i + 3 > tokens.len() {
321            i += 1;
322            continue;
323        }
324        let template = match enter_repetition(&tokens[i..i + 3]) {
325            Some(template) => template,
326            None => {
327                i += 1;
328                continue;
329            }
330        };
331        *found_repetition = true;
332        let mut repeated = Vec::new();
333        for value in range {
334            repeated.extend(substitute_value(var, &value, template.clone()));
335        }
336        let repeated_len = repeated.len();
337        tokens.splice(i..i + 3, repeated);
338        i += repeated_len;
339    }
340
341    TokenStream::from_iter(tokens)
342}
343
344impl Splice<'_> {
345    fn literal(&self) -> Literal {
346        match self.kind {
347            Kind::Int | Kind::Byte => {
348                let repr = match self.radix {
349                    Radix::Binary => format!("0b{0:02$b}{1}", self.int, self.suffix, self.width),
350                    Radix::Octal => format!("0o{0:02$o}{1}", self.int, self.suffix, self.width),
351                    Radix::Decimal => format!("{0:02$}{1}", self.int, self.suffix, self.width),
352                    Radix::LowerHex => format!("0x{0:02$x}{1}", self.int, self.suffix, self.width),
353                    Radix::UpperHex => format!("0x{0:02$X}{1}", self.int, self.suffix, self.width),
354                };
355                let tokens = repr.parse::<TokenStream>().unwrap();
356                let mut iter = tokens.into_iter();
357                let literal = match iter.next() {
358                    Some(TokenTree::Literal(literal)) => literal,
359                    _ => unreachable!(),
360                };
361                assert!(iter.next().is_none());
362                literal
363            }
364            Kind::Char => {
365                let ch = char::from_u32(self.int as u32).unwrap();
366                Literal::character(ch)
367            }
368        }
369    }
370}