derivative/
bound.rs

1/* This file incorporates work covered by the following copyright and
2 * permission notice:
3 *   Copyright 2016 The serde Developers. See
4 *   https://github.com/serde-rs/serde/blob/3f28a9324042950afa80354722aeeee1a55cbfa3/README.md#license.
5 *
6 *   Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
7 *   http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
8 *   <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
9 *   option. This file may not be copied, modified, or distributed
10 *   except according to those terms.
11 */
12
13use ast;
14use attr;
15use std::collections::HashSet;
16use syn::{self, visit, GenericParam};
17
18// use internals::ast::Item;
19// use internals::attr;
20
21/// Remove the default from every type parameter because in the generated `impl`s
22/// they look like associated types: "error: associated type bindings are not
23/// allowed here".
24pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
25    syn::Generics {
26        params: generics
27            .params
28            .iter()
29            .map(|generic_param| match *generic_param {
30                GenericParam::Type(ref ty_param) => syn::GenericParam::Type(syn::TypeParam {
31                    default: None,
32                    ..ty_param.clone()
33                }),
34                ref param => param.clone(),
35            })
36            .collect(),
37        ..generics.clone()
38    }
39}
40
41pub fn with_where_predicates(
42    generics: &syn::Generics,
43    predicates: &[syn::WherePredicate],
44) -> syn::Generics {
45    let mut cloned = generics.clone();
46    cloned
47        .make_where_clause()
48        .predicates
49        .extend(predicates.iter().cloned());
50    cloned
51}
52
53pub fn with_where_predicates_from_fields<F>(
54    item: &ast::Input,
55    generics: &syn::Generics,
56    from_field: F,
57) -> syn::Generics
58where
59    F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
60{
61    let mut cloned = generics.clone();
62    {
63        let fields = item.body.all_fields();
64        let field_where_predicates = fields
65            .iter()
66            .flat_map(|field| from_field(&field.attrs))
67            .flat_map(|predicates| predicates.to_vec());
68
69        cloned
70            .make_where_clause()
71            .predicates
72            .extend(field_where_predicates);
73    }
74    cloned
75}
76
77/// Puts the given bound on any generic type parameters that are used in fields
78/// for which filter returns true.
79///
80/// For example, the following structure needs the bound `A: Debug, B: Debug`.
81///
82/// ```ignore
83/// struct S<'b, A, B: 'b, C> {
84///     a: A,
85///     b: Option<&'b B>
86///     #[derivative(Debug="ignore")]
87///     c: C,
88/// }
89/// ```
90pub fn with_bound<F>(
91    item: &ast::Input,
92    generics: &syn::Generics,
93    filter: F,
94    bound: &syn::Path,
95) -> syn::Generics
96where
97    F: Fn(&attr::Field) -> bool,
98{
99    #[derive(Debug)]
100    struct FindTyParams {
101        /// Set of all generic type parameters on the current struct (A, B, C in
102        /// the example). Initialized up front.
103        all_ty_params: HashSet<syn::Ident>,
104        /// Set of generic type parameters used in fields for which filter
105        /// returns true (A and B in the example). Filled in as the visitor sees
106        /// them.
107        relevant_ty_params: HashSet<syn::Ident>,
108    }
109    impl<'ast> visit::Visit<'ast> for FindTyParams {
110        fn visit_path(&mut self, path: &'ast syn::Path) {
111            if is_phantom_data(path) {
112                // Hardcoded exception, because `PhantomData<T>` implements
113                // most traits whether or not `T` implements it.
114                return;
115            }
116            if path.leading_colon.is_none() && path.segments.len() == 1 {
117                let id = &path.segments[0].ident;
118                if self.all_ty_params.contains(id) {
119                    self.relevant_ty_params.insert(id.clone());
120                }
121            }
122            visit::visit_path(self, path);
123        }
124    }
125
126    let all_ty_params: HashSet<_> = generics
127        .type_params()
128        .map(|ty_param| ty_param.ident.clone())
129        .collect();
130
131    let relevant_tys = item
132        .body
133        .all_fields()
134        .into_iter()
135        .filter(|field| {
136            if let syn::Type::Path(syn::TypePath { ref path, .. }) = *field.ty {
137                !is_phantom_data(path)
138            } else {
139                true
140            }
141        })
142        .filter(|field| filter(&field.attrs))
143        .map(|field| &field.ty);
144
145    let mut visitor = FindTyParams {
146        all_ty_params,
147        relevant_ty_params: HashSet::new(),
148    };
149    for ty in relevant_tys {
150        visit::visit_type(&mut visitor, ty);
151    }
152
153    let mut cloned = generics.clone();
154    {
155        let relevant_where_predicates = generics
156            .type_params()
157            .map(|ty_param| &ty_param.ident)
158            .filter(|id| visitor.relevant_ty_params.contains(id))
159            .map(|id| -> syn::WherePredicate { parse_quote!( #id : #bound ) });
160
161        cloned
162            .make_where_clause()
163            .predicates
164            .extend(relevant_where_predicates);
165    }
166    cloned
167}
168
169#[allow(clippy::match_like_matches_macro)] // needs rustc 1.42
170fn is_phantom_data(path: &syn::Path) -> bool {
171    match path.segments.last() {
172        Some(path) if path.ident == "PhantomData" => true,
173        _ => false,
174    }
175}