mz_ore_proc/
static_list.rs
1use std::collections::VecDeque;
17
18use proc_macro::TokenStream;
19use proc_macro2::{Punct, Spacing, TokenStream as TokenStream2};
20use quote::{ToTokens, TokenStreamExt, quote};
21use syn::parse::Parse;
22use syn::spanned::Spanned;
23use syn::{
24 Error, Ident, Item, ItemMod, LitInt, LitStr, Token, Type, Visibility, parse_macro_input,
25};
26
27pub fn static_list_impl(args: TokenStream, item: TokenStream) -> TokenStream {
29 let args = parse_macro_input!(args as StaticListArgs);
30 let item = parse_macro_input!(item as ItemMod);
31
32 let static_items = match collect_items(&item, &args.ty) {
33 Ok(items) => items,
34 Err(e) => return e.to_compile_error().into(),
35 };
36
37 let expected_count = match args.expected_count.base10_parse::<usize>() {
39 Ok(c) => c,
40 Err(e) => return e.to_compile_error().into(),
41 };
42 if static_items.len() != expected_count {
43 let msg = format!(
44 "Expected {} items, static list would contain {}",
45 expected_count,
46 static_items.len()
47 );
48 let err = syn::Error::new(item.span(), &msg);
49 return err.to_compile_error().into();
50 }
51
52 let name = syn::Ident::new(&args.name.value(), args.name.span());
53 let ty = syn::Ident::new(&args.ty.value(), args.ty.span());
54
55 let expanded = quote! {
56 pub static #name : &[ &'static #ty ] = &[
57 #static_items
58 ];
59
60 #item
61 };
62
63 expanded.into()
64}
65
66#[derive(Debug)]
67struct StaticItem<'i>(VecDeque<&'i Ident>);
68
69impl<'i> StaticItem<'i> {
70 pub fn new(ident: &'i Ident) -> Self {
71 StaticItem(VecDeque::from([ident]))
72 }
73
74 pub fn to_path(&self) -> syn::Path {
75 syn::Path {
76 leading_colon: None,
77 segments: self
78 .0
79 .iter()
80 .copied()
81 .cloned()
82 .map(|i| syn::PathSegment {
83 ident: i,
84 arguments: syn::PathArguments::None,
85 })
86 .collect(),
87 }
88 }
89}
90
91#[derive(Debug)]
92struct StaticItems<'a>(Vec<StaticItem<'a>>);
93
94impl<'a> StaticItems<'a> {
95 fn len(&self) -> usize {
96 self.0.len()
97 }
98}
99
100impl<'a> IntoIterator for StaticItems<'a> {
101 type Item = StaticItem<'a>;
102 type IntoIter = <Vec<StaticItem<'a>> as IntoIterator>::IntoIter;
103
104 fn into_iter(self) -> Self::IntoIter {
105 self.0.into_iter()
106 }
107}
108
109impl<'a> ToTokens for StaticItems<'a> {
110 fn to_tokens(&self, tokens: &mut TokenStream2) {
111 for path in self.0.iter().map(|item| item.to_path()) {
112 tokens.append(Punct::new('&', Spacing::Joint));
113 path.to_tokens(tokens);
114 tokens.append(Punct::new(',', Spacing::Joint));
115 }
116 }
117}
118
119fn collect_items<'i, 't>(
120 item_mod: &'i ItemMod,
121 expected_ty: &'t LitStr,
122) -> syn::Result<StaticItems<'i>> {
123 let items = item_mod.content.as_ref().map(|c| &c.1[..]).unwrap_or(&[]);
124
125 let Visibility::Public(_) = item_mod.vis else {
127 return Err(Error::new_spanned(
128 item_mod,
129 "Modules in a #[static_list] must be `pub`",
130 ));
131 };
132
133 let mut static_items = Vec::new();
134 for item in items {
135 match item {
136 Item::Static(item_static) if type_matches(&item_static.ty, expected_ty) => {
137 let Visibility::Public(_) = item_static.vis else {
141 return Err(Error::new_spanned(
142 item_static,
143 "All items in a #[static_list] must be `pub`",
144 ));
145 };
146
147 static_items.push(StaticItem::new(&item_static.ident));
148 }
149 Item::Mod(nested_item_mod) => {
150 let nested_items = collect_items(nested_item_mod, expected_ty)?;
151 static_items.extend(nested_items.into_iter());
152 }
153 _ => (),
154 }
155 }
156
157 for static_item in &mut static_items {
158 static_item.0.push_front(&item_mod.ident);
159 }
160
161 Ok(StaticItems(static_items))
162}
163
164#[derive(Debug)]
165struct StaticListArgs {
166 pub ty: LitStr,
168 pub name: LitStr,
173 pub expected_count: LitInt,
175}
176
177impl Parse for StaticListArgs {
178 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
179 let mut ty: Option<LitStr> = None;
180 let mut name: Option<LitStr> = None;
181 let mut expected_count: Option<LitInt> = None;
182
183 while !input.is_empty() {
184 let lookahead = input.lookahead1();
185 if lookahead.peek(keywords::ty) {
186 let res = input.parse::<KeyValueArg<keywords::ty, LitStr>>()?;
187 ty = Some(res.val);
188 } else if lookahead.peek(keywords::name) {
189 let res = input.parse::<KeyValueArg<keywords::name, LitStr>>()?;
190 name = Some(res.val);
191 } else if lookahead.peek(keywords::expected_count) {
192 let res = input.parse::<KeyValueArg<keywords::expected_count, LitInt>>()?;
193 expected_count = Some(res.val);
194 } else if lookahead.peek(Token![,]) {
195 let _ = input.parse::<Token![,]>()?;
196 } else {
197 return Err(input.error("Unexpected argument"));
198 }
199 }
200
201 let mut missing_args = Vec::new();
202 if ty.is_none() {
203 missing_args.push("ty");
204 }
205 if name.is_none() {
206 missing_args.push("name");
207 }
208 if expected_count.is_none() {
209 missing_args.push("expected_count");
210 }
211
212 if !missing_args.is_empty() {
213 input.error(format!("Missing arguments {:?}", missing_args));
214 }
215
216 Ok(StaticListArgs {
217 ty: ty.expect("checked above"),
218 name: name.expect("checked above"),
219 expected_count: expected_count.expect("checked above"),
220 })
221 }
222}
223
224fn type_matches(static_ty: &Type, expected_ty: &LitStr) -> bool {
225 static_ty
227 .into_token_stream()
228 .to_string()
229 .ends_with(&expected_ty.value())
230}
231
232struct KeyValueArg<K, V> {
233 _key: K,
234 val: V,
235}
236
237impl<K: Parse, V: Parse> Parse for KeyValueArg<K, V> {
238 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
239 let key = input.parse::<K>()?;
240 let _ = input.parse::<Token![=]>()?;
241 let val = input.parse::<V>()?;
242
243 Ok(Self { _key: key, val })
244 }
245}
246
247mod keywords {
248 syn::custom_keyword!(ty);
249 syn::custom_keyword!(name);
250 syn::custom_keyword!(expected_count);
251}