derivative/
matcher.rs

1#![allow(dead_code)] // TODO: remove
2
// This is inspired by `synstructure`, but `synstructure` is not suitable in several ways,
// including:
//     * `&mut` everywhere
//     * not generic, we use our own `ast`, `synstructure` only knows about `syn`
//     * missing information (what arm are we in?, what attributes? etc.)
8
9use proc_macro2::{self, TokenStream};
10use quote::ToTokens;
11use syn;
12
13use ast;
14use attr;
15use quote;
16
/// How a field is bound when a match pattern is generated.
#[derive(Debug, Copy, Clone)]
pub enum BindingStyle {
    /// Bind by value: `x`
    Move,
    /// Bind by value, mutably: `mut x`
    MoveMut,
    /// Bind by shared reference: `ref x`
    Ref,
    /// Bind by mutable reference: `ref mut x`
    RefMut,
}

impl BindingStyle {
    /// Returns the style to use given whether the matched type is packed.
    ///
    /// When `is_packed` is `true`, the by-reference styles are replaced by their by-value
    /// counterparts (`Ref` becomes `Move`, `RefMut` becomes `MoveMut`). Every other
    /// combination returns `self` unchanged.
    fn with_packed(self, is_packed: bool) -> BindingStyle {
        match (self, is_packed) {
            (BindingStyle::Ref, true) => BindingStyle::Move,
            (BindingStyle::RefMut, true) => BindingStyle::MoveMut,
            (unchanged, _) => unchanged,
        }
    }
}
40
41impl quote::ToTokens for BindingStyle {
42    fn to_tokens(&self, tokens: &mut TokenStream) {
43        match *self {
44            BindingStyle::Move => (),
45            BindingStyle::MoveMut => tokens.extend(quote!(mut)),
46            BindingStyle::Ref => tokens.extend(quote!(ref)),
47            BindingStyle::RefMut => {
48                tokens.extend(quote!(ref mut));
49            }
50        }
51    }
52}
53
54#[derive(Debug)]
55pub struct BindingInfo<'a> {
56    pub expr: TokenStream,
57    pub ident: syn::Ident,
58    pub field: &'a ast::Field<'a>,
59}
60
61#[derive(Debug)]
62pub struct CommonVariant<'a> {
63    path: syn::Path,
64    name: &'a syn::Ident,
65    style: ast::Style,
66    attrs: &'a attr::Input,
67}
68
69pub struct Matcher<T> {
70    binding_name: String,
71    binding_style: BindingStyle,
72    is_packed: bool,
73    field_filter: T,
74}
75
76impl Matcher<fn (&ast::Field) -> bool> {
77    pub fn new(style: BindingStyle, is_packed: bool) -> Self {
78        Matcher {
79            binding_name: "__arg".into(),
80            binding_style: style.with_packed(is_packed),
81            is_packed,
82            field_filter: |_| true,
83        }
84    }
85}
86
87impl<T: Fn (&ast::Field) -> bool> Matcher<T> {
88    pub fn with_name(self, name: String) -> Self {
89        Matcher {
90            binding_name: name,
91            ..self
92        }
93    }
94
95    pub fn with_field_filter<P>(self, field_filter: P) -> Matcher<P> {
96        Matcher {
97            field_filter,
98            binding_name: self.binding_name,
99            binding_style: self.binding_style,
100            is_packed: self.is_packed,
101        }
102    }
103
104    pub fn build_arms<F>(self, input: &ast::Input, binding_name: &str, f: F) -> TokenStream
105    where
106        F: Fn(
107            syn::Path,
108            usize,
109            &syn::Ident,
110            ast::Style,
111            &attr::Input,
112            Vec<BindingInfo>,
113        ) -> TokenStream,
114    {
115        let variants = self.build_match_pattern(input, binding_name);
116
117        // Now that we have the patterns, generate the actual branches of the match
118        // expression
119        let mut t = TokenStream::new();
120        for (i, (variant, (pat, bindings))) in variants.into_iter().enumerate() {
121            let body = f(
122                variant.path,
123                i,
124                variant.name,
125                variant.style,
126                variant.attrs,
127                bindings,
128            );
129            quote!(#pat => { #body }).to_tokens(&mut t);
130        }
131
132        t
133    }
134
135    pub fn build_2_arms<F>(
136        self,
137        (left_matched_expr, right_matched_expr): (TokenStream, TokenStream),
138        left: (&ast::Input, &str),
139        right: (&ast::Input, &str),
140        f: F,
141    ) -> TokenStream
142    where
143        F: Fn(
144            usize,
145            CommonVariant,
146            CommonVariant,
147            (Vec<BindingInfo>, Vec<BindingInfo>),
148        ) -> TokenStream,
149    {
150        let left_variants = self.build_match_pattern(left.0, left.1);
151        let right_variants = self.build_match_pattern(right.0, right.1);
152
153        assert_eq!(left_variants.len(), right_variants.len());
154
155        if left_variants.len() == 1 {
156            let (left, (left_pat, left_bindings)) = left_variants.into_iter().next().unwrap();
157            let (right, (right_pat, right_bindings)) = right_variants.into_iter().next().unwrap();
158
159            let body = f(0, left, right, (left_bindings, right_bindings));
160
161            quote! {
162                match #left_matched_expr {
163                    #left_pat => match #right_matched_expr {
164                        #right_pat => #body,
165                    },
166                }
167            }
168        } else {
169            // Now that we have the patterns, generate the actual branches of the match
170            // expression
171            let mut t = TokenStream::new();
172            for (i, (left, right)) in left_variants.into_iter().zip(right_variants).enumerate() {
173                let (left, (left_pat, left_bindings)) = left;
174                let (right, (right_pat, right_bindings)) = right;
175
176                let body = f(i, left, right, (left_bindings, right_bindings));
177                quote!((#left_pat, #right_pat) => { #body }).to_tokens(&mut t);
178            }
179
180            quote! {
181                match (&#left_matched_expr, &#right_matched_expr) {
182                    #t
183                    _ => unreachable!(),
184                }
185            }
186        }
187    }
188
189    /// Generate patterns for matching against all of the variants
190    pub fn build_match_pattern<'a>(
191        &self,
192        input: &'a ast::Input,
193        binding_name: &str,
194    ) -> Vec<(CommonVariant<'a>, (TokenStream, Vec<BindingInfo<'a>>))> {
195        let ident = &input.ident;
196
197        match input.body {
198            ast::Body::Enum(ref variants) => variants
199                .iter()
200                .map(|variant| {
201                    let variant_ident = &variant.ident;
202                    let path = parse_quote!(#ident::#variant_ident);
203
204                    let pat = self.build_match_pattern_impl(
205                        &path,
206                        variant.style,
207                        &variant.fields,
208                        binding_name,
209                    );
210
211                    (
212                        CommonVariant {
213                            path,
214                            name: variant_ident,
215                            style: variant.style,
216                            attrs: &variant.attrs,
217                        },
218                        pat,
219                    )
220                })
221                .collect(),
222            ast::Body::Struct(style, ref vd) => {
223                let path = parse_quote!(#ident);
224                vec![(
225                    CommonVariant {
226                        path,
227                        name: ident,
228                        style,
229                        attrs: &input.attrs,
230                    },
231                    self.build_match_pattern_impl(ident, style, vd, binding_name),
232                )]
233            }
234        }
235    }
236
237    fn build_match_pattern_impl<'a, N>(
238        &self,
239        name: &N,
240        style: ast::Style,
241        fields: &'a [ast::Field<'a>],
242        binding_name: &str,
243    ) -> (TokenStream, Vec<BindingInfo<'a>>)
244    where
245        N: quote::ToTokens,
246    {
247        let (stream, matches) = match style {
248            ast::Style::Unit => (TokenStream::new(), Vec::new()),
249            ast::Style::Tuple => {
250                let (stream, matches) = fields.iter().enumerate().fold(
251                    (TokenStream::new(), Vec::new()),
252                    |(stream, matches), field| {
253                        self.build_inner_pattern(
254                            (stream, matches),
255                            field,
256                            binding_name,
257                            |f, ident, binding| {
258                                if (self.field_filter)(f) {
259                                    quote!(#binding #ident ,)
260                                } else {
261                                    quote!(_ ,)
262                                }
263                            },
264                        )
265                    },
266                );
267
268                (quote! { ( #stream ) }, matches)
269            }
270            ast::Style::Struct => {
271                let (stream, matches) = fields.iter().enumerate().fold(
272                    (TokenStream::new(), Vec::new()),
273                    |(stream, matches), field| {
274                        self.build_inner_pattern(
275                            (stream, matches),
276                            field,
277                            binding_name,
278                            |field, ident, binding| {
279                                let field_name = field.ident.as_ref().unwrap();
280                                if (self.field_filter)(field) {
281                                    quote!(#field_name : #binding #ident ,)
282                                } else {
283                                    quote!(#field_name : _ ,)
284                                }
285                            },
286                        )
287                    },
288                );
289
290                (quote! { { #stream } }, matches)
291            }
292        };
293
294        let mut all_tokens = TokenStream::new();
295        name.to_tokens(&mut all_tokens);
296        all_tokens.extend(stream);
297
298        (all_tokens, matches)
299    }
300
301    fn build_inner_pattern<'a>(
302        &self,
303        (mut stream, mut matches): (TokenStream, Vec<BindingInfo<'a>>),
304        (i, field): (usize, &'a ast::Field),
305        binding_name: &str,
306        f: impl FnOnce(&ast::Field, &syn::Ident, BindingStyle) -> TokenStream,
307    ) -> (TokenStream, Vec<BindingInfo<'a>>) {
308        let binding_style = self.binding_style;
309
310        let ident: syn::Ident = syn::Ident::new(
311            &format!("{}_{}", binding_name, i),
312            proc_macro2::Span::call_site(),
313        );
314        let expr = syn::Expr::Path(syn::ExprPath {
315            attrs: vec![],
316            qself: None,
317            path: syn::Path::from(ident.clone())
318        });
319
320        let expr = if self.is_packed {
321            expr.into_token_stream()
322        } else {
323            quote!((*#expr))
324        };
325
326        f(field, &ident, binding_style).to_tokens(&mut stream);
327
328        matches.push(BindingInfo {
329            expr,
330            ident,
331            field,
332        });
333
334        (stream, matches)
335    }
336}