const_fn/
ast.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro::{Delimiter, Ident, Literal, Span, TokenStream, TokenTree};
4
5use crate::{iter::TokenIter, to_tokens::ToTokens, utils::tt_span, Result};
6
7pub(crate) struct Func {
8    attrs: Vec<Attribute>,
9    // [const] [async] [unsafe] [extern [<abi>]] fn
10    sig: Vec<TokenTree>,
11    body: TokenStream,
12    pub(crate) print_const: bool,
13}
14
15pub(crate) fn parse_input(input: TokenStream) -> Result<Func> {
16    let input = &mut TokenIter::new(input);
17
18    let attrs = parse_attrs(input)?;
19    let sig = parse_signature(input);
20    let body: TokenStream = input.collect();
21
22    if body.is_empty()
23        || !sig
24            .iter()
25            .any(|tt| if let TokenTree::Ident(i) = tt { i.to_string() == "fn" } else { false })
26    {
27        bail!(Span::call_site(), "#[const_fn] attribute may only be used on functions");
28    }
29
30    Ok(Func { attrs, sig, body, print_const: true })
31}
32
33impl ToTokens for Func {
34    fn to_tokens(&self, tokens: &mut TokenStream) {
35        self.attrs.iter().for_each(|attr| attr.to_tokens(tokens));
36        if self.print_const {
37            self.sig.iter().for_each(|attr| attr.to_tokens(tokens));
38        } else {
39            self.sig
40                .iter()
41                .filter(
42                    |tt| if let TokenTree::Ident(i) = tt { i.to_string() != "const" } else { true },
43                )
44                .for_each(|tt| tt.to_tokens(tokens));
45        }
46        self.body.to_tokens(tokens);
47    }
48}
49
50fn parse_signature(input: &mut TokenIter) -> Vec<TokenTree> {
51    let mut sig = vec![];
52    let mut has_const = false;
53    loop {
54        match input.peek() {
55            None => break,
56            Some(TokenTree::Ident(i)) if !has_const => {
57                match &*i.to_string() {
58                    "fn" => {
59                        sig.push(TokenTree::Ident(Ident::new("const", i.span())));
60                        sig.push(input.next().unwrap());
61                        break;
62                    }
63                    "const" => {
64                        has_const = true;
65                    }
66                    "async" | "unsafe" | "extern" => {
67                        has_const = true;
68                        sig.push(TokenTree::Ident(Ident::new("const", i.span())));
69                    }
70                    _ => {}
71                }
72                sig.push(input.next().unwrap());
73            }
74            Some(TokenTree::Ident(i)) if i.to_string() == "fn" => {
75                sig.push(input.next().unwrap());
76                break;
77            }
78            Some(_) => sig.push(input.next().unwrap()),
79        }
80    }
81    sig
82}
83
84fn parse_attrs(input: &mut TokenIter) -> Result<Vec<Attribute>> {
85    let mut attrs = vec![];
86    loop {
87        let pound_token = match input.peek() {
88            Some(TokenTree::Punct(p)) if p.as_char() == '#' => input.next().unwrap(),
89            _ => break,
90        };
91        let group = match input.peek() {
92            Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
93                input.next().unwrap()
94            }
95            tt => bail!(tt_span(tt), "expected `[`"),
96        };
97        attrs.push(Attribute { pound_token, group });
98    }
99    Ok(attrs)
100}
101
102pub(crate) struct Attribute {
103    // `#`
104    pub(crate) pound_token: TokenTree,
105    // `[...]`
106    pub(crate) group: TokenTree,
107}
108
109impl ToTokens for Attribute {
110    fn to_tokens(&self, tokens: &mut TokenStream) {
111        self.pound_token.to_tokens(tokens);
112        self.group.to_tokens(tokens);
113    }
114}
115
116pub(crate) struct LitStr {
117    pub(crate) token: Literal,
118    value: String,
119}
120
121impl LitStr {
122    pub(crate) fn new(token: Literal) -> Result<Self> {
123        let value = token.to_string();
124        // unlike `syn::LitStr`, only accepts `"..."`
125        if value.starts_with('"') && value.ends_with('"') {
126            Ok(Self { token, value })
127        } else {
128            bail!(token.span(), "expected string literal")
129        }
130    }
131
132    pub(crate) fn value(&self) -> &str {
133        &self.value[1..self.value.len() - 1]
134    }
135
136    pub(crate) fn span(&self) -> Span {
137        self.token.span()
138    }
139}