derivative/
cmp.rs

1// https://github.com/rust-lang/rust/issues/13101
2
3use ast;
4use attr;
5use matcher;
6use paths;
7use proc_macro2;
8use syn;
9use utils;
10
11/// Derive `Eq` for `input`.
12pub fn derive_eq(input: &ast::Input) -> proc_macro2::TokenStream {
13    let name = &input.ident;
14
15    let eq_trait_path = eq_trait_path();
16    let generics = utils::build_impl_generics(
17        input,
18        &eq_trait_path,
19        needs_eq_bound,
20        |field| field.eq_bound(),
21        |input| input.eq_bound(),
22    );
23    let new_where_clause;
24    let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
25
26    if let Some(new_where_clause2) =
27        maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_eq())
28    {
29        new_where_clause = new_where_clause2;
30        where_clause = Some(&new_where_clause);
31    }
32
33    quote! {
34        #[allow(unused_qualifications)]
35        impl #impl_generics #eq_trait_path for #name #ty_generics #where_clause {}
36    }
37}
38
39/// Derive `PartialEq` for `input`.
40pub fn derive_partial_eq(input: &ast::Input) -> proc_macro2::TokenStream {
41    let discriminant_cmp = if let ast::Body::Enum(_) = input.body {
42        let discriminant_path = paths::discriminant_path();
43
44        quote!((#discriminant_path(&*self) == #discriminant_path(&*other)))
45    } else {
46        quote!(true)
47    };
48
49    let name = &input.ident;
50
51    let partial_eq_trait_path = partial_eq_trait_path();
52    let generics = utils::build_impl_generics(
53        input,
54        &partial_eq_trait_path,
55        needs_partial_eq_bound,
56        |field| field.partial_eq_bound(),
57        |input| input.partial_eq_bound(),
58    );
59    let new_where_clause;
60    let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
61
62    let match_fields = if input.is_trivial_enum() {
63        quote!(true)
64    } else {
65        matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
66            .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_eq())
67            .build_2_arms(
68                (quote!(*self), quote!(*other)),
69                (input, "__self"),
70                (input, "__other"),
71                |_, _, _, (left_variant, right_variant)| {
72                    let cmp = left_variant.iter().zip(&right_variant).map(|(o, i)| {
73                        let outer_name = &o.expr;
74                        let inner_name = &i.expr;
75
76                        if o.field.attrs.ignore_partial_eq() {
77                            None
78                        } else if let Some(compare_fn) = o.field.attrs.partial_eq_compare_with() {
79                            Some(quote!(&& #compare_fn(&#outer_name, &#inner_name)))
80                        } else {
81                            Some(quote!(&& &#outer_name == &#inner_name))
82                        }
83                    });
84
85                    quote!(true #(#cmp)*)
86                },
87            )
88    };
89
90    if let Some(new_where_clause2) =
91        maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_eq())
92    {
93        new_where_clause = new_where_clause2;
94        where_clause = Some(&new_where_clause);
95    }
96
97    quote! {
98        #[allow(unused_qualifications)]
99        #[allow(clippy::unneeded_field_pattern)]
100        impl #impl_generics #partial_eq_trait_path for #name #ty_generics #where_clause {
101            fn eq(&self, other: &Self) -> bool {
102                #discriminant_cmp && #match_fields
103            }
104        }
105    }
106}
107
108/// Derive `PartialOrd` for `input`.
109pub fn derive_partial_ord(
110    input: &ast::Input,
111    errors: &mut proc_macro2::TokenStream,
112) -> proc_macro2::TokenStream {
113    if let ast::Body::Enum(_) = input.body {
114        if !input.attrs.partial_ord_on_enum() {
115            let message = "can't use `#[derivative(PartialOrd)]` on an enumeration without \
116            `feature_allow_slow_enum`; see the documentation for more details";
117            errors.extend(syn::Error::new(input.span, message).to_compile_error());
118        }
119    }
120
121    let option_path = option_path();
122    let ordering_path = ordering_path();
123
124    let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
125        .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_ord())
126        .build_arms(input, "__self", |_, n, _, _, _, outer_bis| {
127            let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
128                .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_ord())
129                .build_arms(input, "__other", |_, m, _, _, _, inner_bis| {
130                    match n.cmp(&m) {
131                        ::std::cmp::Ordering::Less => {
132                            quote!(#option_path::Some(#ordering_path::Less))
133                        }
134                        ::std::cmp::Ordering::Greater => {
135                            quote!(#option_path::Some(#ordering_path::Greater))
136                        }
137                        ::std::cmp::Ordering::Equal => {
138                            let equal_path = quote!(#ordering_path::Equal);
139                            outer_bis
140                                .iter()
141                                .rev()
142                                .zip(inner_bis.into_iter().rev())
143                                .fold(quote!(#option_path::Some(#equal_path)), |acc, (o, i)| {
144                                    let outer_name = &o.expr;
145                                    let inner_name = &i.expr;
146
147                                    if o.field.attrs.ignore_partial_ord() {
148                                        acc
149                                    } else {
150                                        let cmp_fn = o
151                                            .field
152                                            .attrs
153                                            .partial_ord_compare_with()
154                                            .map(|f| quote!(#f))
155                                            .unwrap_or_else(|| {
156                                                let path = partial_ord_trait_path();
157                                                quote!(#path::partial_cmp)
158                                            });
159
160                                        quote!(match #cmp_fn(&#outer_name, &#inner_name) {
161                                            #option_path::Some(#equal_path) => #acc,
162                                            __derive_ordering_other => __derive_ordering_other,
163                                        })
164                                    }
165                                })
166                        }
167                    }
168                });
169
170            quote! {
171                match *other {
172                    #body
173                }
174
175            }
176        });
177
178    let name = &input.ident;
179
180    let partial_ord_trait_path = partial_ord_trait_path();
181    let generics = utils::build_impl_generics(
182        input,
183        &partial_ord_trait_path,
184        needs_partial_ord_bound,
185        |field| field.partial_ord_bound(),
186        |input| input.partial_ord_bound(),
187    );
188    let new_where_clause;
189    let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
190
191    if let Some(new_where_clause2) =
192        maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_ord())
193    {
194        new_where_clause = new_where_clause2;
195        where_clause = Some(&new_where_clause);
196    }
197
198    quote! {
199        #[allow(unused_qualifications)]
200        #[allow(clippy::unneeded_field_pattern)]
201        impl #impl_generics #partial_ord_trait_path for #name #ty_generics #where_clause {
202            fn partial_cmp(&self, other: &Self) -> #option_path<#ordering_path> {
203                match *self {
204                    #body
205                }
206            }
207        }
208    }
209}
210
211/// Derive `Ord` for `input`.
212pub fn derive_ord(
213    input: &ast::Input,
214    errors: &mut proc_macro2::TokenStream,
215) -> proc_macro2::TokenStream {
216    if let ast::Body::Enum(_) = input.body {
217        if !input.attrs.ord_on_enum() {
218            let message = "can't use `#[derivative(Ord)]` on an enumeration without \
219            `feature_allow_slow_enum`; see the documentation for more details";
220            errors.extend(syn::Error::new(input.span, message).to_compile_error());
221        }
222    }
223
224    let ordering_path = ordering_path();
225
226    let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
227        .with_field_filter(|f: &ast::Field| !f.attrs.ignore_ord())
228        .build_arms(input, "__self", |_, n, _, _, _, outer_bis| {
229            let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
230                .with_field_filter(|f: &ast::Field| !f.attrs.ignore_ord())
231                .build_arms(input, "__other", |_, m, _, _, _, inner_bis| {
232                    match n.cmp(&m) {
233                        ::std::cmp::Ordering::Less => quote!(#ordering_path::Less),
234                        ::std::cmp::Ordering::Greater => quote!(#ordering_path::Greater),
235                        ::std::cmp::Ordering::Equal => {
236                            let equal_path = quote!(#ordering_path::Equal);
237                            outer_bis
238                                .iter()
239                                .rev()
240                                .zip(inner_bis.into_iter().rev())
241                                .fold(quote!(#equal_path), |acc, (o, i)| {
242                                    let outer_name = &o.expr;
243                                    let inner_name = &i.expr;
244
245                                    if o.field.attrs.ignore_ord() {
246                                        acc
247                                    } else {
248                                        let cmp_fn = o
249                                            .field
250                                            .attrs
251                                            .ord_compare_with()
252                                            .map(|f| quote!(#f))
253                                            .unwrap_or_else(|| {
254                                                let path = ord_trait_path();
255                                                quote!(#path::cmp)
256                                            });
257
258                                        quote!(match #cmp_fn(&#outer_name, &#inner_name) {
259                                           #equal_path => #acc,
260                                            __derive_ordering_other => __derive_ordering_other,
261                                        })
262                                    }
263                                })
264                        }
265                    }
266                });
267
268            quote! {
269                match *other {
270                    #body
271                }
272
273            }
274        });
275
276    let name = &input.ident;
277
278    let ord_trait_path = ord_trait_path();
279    let generics = utils::build_impl_generics(
280        input,
281        &ord_trait_path,
282        needs_ord_bound,
283        |field| field.ord_bound(),
284        |input| input.ord_bound(),
285    );
286    let new_where_clause;
287    let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
288
289    if let Some(new_where_clause2) = maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_ord())
290    {
291        new_where_clause = new_where_clause2;
292        where_clause = Some(&new_where_clause);
293    }
294
295    quote! {
296        #[allow(unused_qualifications)]
297        #[allow(clippy::unneeded_field_pattern)]
298        impl #impl_generics #ord_trait_path for #name #ty_generics #where_clause {
299            fn cmp(&self, other: &Self) -> #ordering_path {
300                match *self {
301                    #body
302                }
303            }
304        }
305    }
306}
307
308fn needs_partial_eq_bound(attrs: &attr::Field) -> bool {
309    !attrs.ignore_partial_eq() && attrs.partial_eq_bound().is_none()
310}
311
312fn needs_partial_ord_bound(attrs: &attr::Field) -> bool {
313    !attrs.ignore_partial_ord() && attrs.partial_ord_bound().is_none()
314}
315
316fn needs_ord_bound(attrs: &attr::Field) -> bool {
317    !attrs.ignore_ord() && attrs.ord_bound().is_none()
318}
319
320fn needs_eq_bound(attrs: &attr::Field) -> bool {
321    !attrs.ignore_partial_eq() && attrs.eq_bound().is_none()
322}
323
324/// Return the path of the `Eq` trait, that is `::std::cmp::Eq`.
325fn eq_trait_path() -> syn::Path {
326    if cfg!(feature = "use_core") {
327        parse_quote!(::core::cmp::Eq)
328    } else {
329        parse_quote!(::std::cmp::Eq)
330    }
331}
332
333/// Return the path of the `PartialEq` trait, that is `::std::cmp::PartialEq`.
334fn partial_eq_trait_path() -> syn::Path {
335    if cfg!(feature = "use_core") {
336        parse_quote!(::core::cmp::PartialEq)
337    } else {
338        parse_quote!(::std::cmp::PartialEq)
339    }
340}
341
342/// Return the path of the `PartialOrd` trait, that is `::std::cmp::PartialOrd`.
343fn partial_ord_trait_path() -> syn::Path {
344    if cfg!(feature = "use_core") {
345        parse_quote!(::core::cmp::PartialOrd)
346    } else {
347        parse_quote!(::std::cmp::PartialOrd)
348    }
349}
350
351/// Return the path of the `Ord` trait, that is `::std::cmp::Ord`.
352fn ord_trait_path() -> syn::Path {
353    if cfg!(feature = "use_core") {
354        parse_quote!(::core::cmp::Ord)
355    } else {
356        parse_quote!(::std::cmp::Ord)
357    }
358}
359
360/// Return the path of the `Option` trait, that is `::std::option::Option`.
361fn option_path() -> syn::Path {
362    if cfg!(feature = "use_core") {
363        parse_quote!(::core::option::Option)
364    } else {
365        parse_quote!(::std::option::Option)
366    }
367}
368
369/// Return the path of the `Ordering` trait, that is `::std::cmp::Ordering`.
370fn ordering_path() -> syn::Path {
371    if cfg!(feature = "use_core") {
372        parse_quote!(::core::cmp::Ordering)
373    } else {
374        parse_quote!(::std::cmp::Ordering)
375    }
376}
377
378fn maybe_add_copy(
379    input: &ast::Input,
380    where_clause: Option<&syn::WhereClause>,
381    field_filter: impl Fn(&ast::Field) -> bool,
382) -> Option<syn::WhereClause> {
383    if input.attrs.is_packed && !input.body.is_empty() {
384        let mut new_where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause {
385            where_token: parse_quote!(where),
386            predicates: Default::default(),
387        });
388
389        new_where_clause.predicates.extend(
390            input
391                .body
392                .all_fields()
393                .into_iter()
394                .filter(|f| field_filter(f))
395                .map(|f| {
396                    let ty = f.ty;
397
398                    let pred: syn::WherePredicate = parse_quote!(#ty: Copy);
399                    pred
400                }),
401        );
402
403        Some(new_where_clause)
404    } else {
405        None
406    }
407}