mz_ore_proc/
static_list.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::collections::VecDeque;
17
18use proc_macro::TokenStream;
19use proc_macro2::{Punct, Spacing, TokenStream as TokenStream2};
20use quote::{ToTokens, TokenStreamExt, quote};
21use syn::parse::Parse;
22use syn::spanned::Spanned;
23use syn::{
24    Error, Ident, Item, ItemMod, LitInt, LitStr, Token, Type, Visibility, parse_macro_input,
25};
26
27/// Implementation for the `#[static_list]` macro.
28pub fn static_list_impl(args: TokenStream, item: TokenStream) -> TokenStream {
29    let args = parse_macro_input!(args as StaticListArgs);
30    let item = parse_macro_input!(item as ItemMod);
31
32    let static_items = match collect_items(&item, &args.ty) {
33        Ok(items) => items,
34        Err(e) => return e.to_compile_error().into(),
35    };
36
37    // Make sure our expected count matches how many items we actually collected.
38    let expected_count = match args.expected_count.base10_parse::<usize>() {
39        Ok(c) => c,
40        Err(e) => return e.to_compile_error().into(),
41    };
42    if static_items.len() != expected_count {
43        let msg = format!(
44            "Expected {} items, static list would contain {}",
45            expected_count,
46            static_items.len()
47        );
48        let err = syn::Error::new(item.span(), &msg);
49        return err.to_compile_error().into();
50    }
51
52    let name = syn::Ident::new(&args.name.value(), args.name.span());
53    let ty = syn::Ident::new(&args.ty.value(), args.ty.span());
54
55    let expanded = quote! {
56        pub static #name : &[ &'static #ty ] = &[
57            #static_items
58        ];
59
60        #item
61    };
62
63    expanded.into()
64}
65
66#[derive(Debug)]
67struct StaticItem<'i>(VecDeque<&'i Ident>);
68
69impl<'i> StaticItem<'i> {
70    pub fn new(ident: &'i Ident) -> Self {
71        StaticItem(VecDeque::from([ident]))
72    }
73
74    pub fn to_path(&self) -> syn::Path {
75        syn::Path {
76            leading_colon: None,
77            segments: self
78                .0
79                .iter()
80                .copied()
81                .cloned()
82                .map(|i| syn::PathSegment {
83                    ident: i,
84                    arguments: syn::PathArguments::None,
85                })
86                .collect(),
87        }
88    }
89}
90
91#[derive(Debug)]
92struct StaticItems<'a>(Vec<StaticItem<'a>>);
93
94impl<'a> StaticItems<'a> {
95    fn len(&self) -> usize {
96        self.0.len()
97    }
98}
99
100impl<'a> IntoIterator for StaticItems<'a> {
101    type Item = StaticItem<'a>;
102    type IntoIter = <Vec<StaticItem<'a>> as IntoIterator>::IntoIter;
103
104    fn into_iter(self) -> Self::IntoIter {
105        self.0.into_iter()
106    }
107}
108
109impl<'a> ToTokens for StaticItems<'a> {
110    fn to_tokens(&self, tokens: &mut TokenStream2) {
111        for path in self.0.iter().map(|item| item.to_path()) {
112            tokens.append(Punct::new('&', Spacing::Joint));
113            path.to_tokens(tokens);
114            tokens.append(Punct::new(',', Spacing::Joint));
115        }
116    }
117}
118
119fn collect_items<'i, 't>(
120    item_mod: &'i ItemMod,
121    expected_ty: &'t LitStr,
122) -> syn::Result<StaticItems<'i>> {
123    let items = item_mod.content.as_ref().map(|c| &c.1[..]).unwrap_or(&[]);
124
125    // TODO(parkmycar): Support ignoring modules.
126    let Visibility::Public(_) = item_mod.vis else {
127        return Err(Error::new_spanned(
128            item_mod,
129            "Modules in a #[static_list] must be `pub`",
130        ));
131    };
132
133    let mut static_items = Vec::new();
134    for item in items {
135        match item {
136            Item::Static(item_static) if type_matches(&item_static.ty, expected_ty) => {
137                // For ease of use, we error if a static item isn't public.
138                //
139                // TODO(parkmycar): Support ignoring items.
140                let Visibility::Public(_) = item_static.vis else {
141                    return Err(Error::new_spanned(
142                        item_static,
143                        "All items in a #[static_list] must be `pub`",
144                    ));
145                };
146
147                static_items.push(StaticItem::new(&item_static.ident));
148            }
149            Item::Mod(nested_item_mod) => {
150                let nested_items = collect_items(nested_item_mod, expected_ty)?;
151                static_items.extend(nested_items.into_iter());
152            }
153            _ => (),
154        }
155    }
156
157    for static_item in &mut static_items {
158        static_item.0.push_front(&item_mod.ident);
159    }
160
161    Ok(StaticItems(static_items))
162}
163
164#[derive(Debug)]
165struct StaticListArgs {
166    /// Type of static objects we should use to form the list.
167    pub ty: LitStr,
168    /// Name we should use for the list.
169    ///
170    /// Note: Requring the exact name, instead of some smart default, makes it a lot easier to
171    /// discover where the list is defined using tools like grep.
172    pub name: LitStr,
173    /// Expected count of items, used as a smoke test.
174    pub expected_count: LitInt,
175}
176
177impl Parse for StaticListArgs {
178    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
179        let mut ty: Option<LitStr> = None;
180        let mut name: Option<LitStr> = None;
181        let mut expected_count: Option<LitInt> = None;
182
183        while !input.is_empty() {
184            let lookahead = input.lookahead1();
185            if lookahead.peek(keywords::ty) {
186                let res = input.parse::<KeyValueArg<keywords::ty, LitStr>>()?;
187                ty = Some(res.val);
188            } else if lookahead.peek(keywords::name) {
189                let res = input.parse::<KeyValueArg<keywords::name, LitStr>>()?;
190                name = Some(res.val);
191            } else if lookahead.peek(keywords::expected_count) {
192                let res = input.parse::<KeyValueArg<keywords::expected_count, LitInt>>()?;
193                expected_count = Some(res.val);
194            } else if lookahead.peek(Token![,]) {
195                let _ = input.parse::<Token![,]>()?;
196            } else {
197                return Err(input.error("Unexpected argument"));
198            }
199        }
200
201        let mut missing_args = Vec::new();
202        if ty.is_none() {
203            missing_args.push("ty");
204        }
205        if name.is_none() {
206            missing_args.push("name");
207        }
208        if expected_count.is_none() {
209            missing_args.push("expected_count");
210        }
211
212        if !missing_args.is_empty() {
213            input.error(format!("Missing arguments {:?}", missing_args));
214        }
215
216        Ok(StaticListArgs {
217            ty: ty.expect("checked above"),
218            name: name.expect("checked above"),
219            expected_count: expected_count.expect("checked above"),
220        })
221    }
222}
223
224fn type_matches(static_ty: &Type, expected_ty: &LitStr) -> bool {
225    // Note: This probably isn't super accurrate, but it's easy.
226    static_ty
227        .into_token_stream()
228        .to_string()
229        .ends_with(&expected_ty.value())
230}
231
232struct KeyValueArg<K, V> {
233    _key: K,
234    val: V,
235}
236
237impl<K: Parse, V: Parse> Parse for KeyValueArg<K, V> {
238    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
239        let key = input.parse::<K>()?;
240        let _ = input.parse::<Token![=]>()?;
241        let val = input.parse::<V>()?;
242
243        Ok(Self { _key: key, val })
244    }
245}
246
247mod keywords {
248    syn::custom_keyword!(ty);
249    syn::custom_keyword!(name);
250    syn::custom_keyword!(expected_count);
251}