serde_with_macros/
apply.rs

1use darling::{ast::NestedMeta, Error as DarlingError, FromMeta};
2use proc_macro::TokenStream;
3use quote::ToTokens as _;
4use syn::{
5    parse::{Parse, ParseStream},
6    punctuated::Punctuated,
7    Attribute, Error, Field, Path, Token, Type, TypeArray, TypeGroup, TypeParen, TypePath, TypePtr,
8    TypeReference, TypeSlice, TypeTuple,
9};
10
11/// Parsed form of a single rule in the `#[apply(...)]` attribute.
12///
13/// This parses tokens in the shape of `Type => Attribute`.
14/// For example, `Option<String> => #[serde(default)]`.
15struct AddAttributesRule {
16    /// A type pattern determining the fields to which the attributes are applied.
17    ty: Type,
18    /// The attributes to apply.
19    ///
20    /// All attributes are appended to the list of existing field attributes.
21    attrs: Vec<Attribute>,
22}
23
24impl Parse for AddAttributesRule {
25    fn parse(input: ParseStream<'_>) -> Result<Self, Error> {
26        let ty: Type = input.parse()?;
27        input.parse::<Token![=>]>()?;
28        let attr = Attribute::parse_outer(input)?;
29        Ok(AddAttributesRule { ty, attrs: attr })
30    }
31}
32
33/// Parsed form of the `#[apply(...)]` attribute.
34///
35/// The `apply` attribute takes a comma separated list of rules in the shape of `Type => Attribute`.
36/// Each rule is stored as a [`AddAttributesRule`].
37struct ApplyInput {
38    metas: Vec<NestedMeta>,
39    rules: Punctuated<AddAttributesRule, Token![,]>,
40}
41
42impl Parse for ApplyInput {
43    fn parse(input: ParseStream<'_>) -> Result<Self, Error> {
44        let mut metas: Vec<NestedMeta> = Vec::new();
45
46        while input.peek2(Token![=]) && !input.peek2(Token![=>]) {
47            let value = NestedMeta::parse(input)?;
48            metas.push(value);
49            if !input.peek(Token![,]) {
50                break;
51            }
52            input.parse::<Token![,]>()?;
53        }
54
55        let rules: Punctuated<AddAttributesRule, Token![,]> =
56            input.parse_terminated(AddAttributesRule::parse, Token![,])?;
57        Ok(Self { metas, rules })
58    }
59}
60
61pub fn apply(args: TokenStream, input: TokenStream) -> TokenStream {
62    let args = syn::parse_macro_input!(args as ApplyInput);
63
64    #[derive(FromMeta)]
65    struct SerdeContainerOptions {
66        #[darling(rename = "crate")]
67        alt_crate_path: Option<Path>,
68    }
69
70    let container_options = match SerdeContainerOptions::from_list(&args.metas) {
71        Ok(v) => v,
72        Err(e) => {
73            return TokenStream::from(e.write_errors());
74        }
75    };
76    let serde_with_crate_path = container_options
77        .alt_crate_path
78        .unwrap_or_else(|| syn::parse_quote!(::serde_with));
79
80    let res = match super::apply_function_to_struct_and_enum_fields_darling(
81        input,
82        &serde_with_crate_path,
83        &prepare_apply_attribute_to_field(args),
84    ) {
85        Ok(res) => res,
86        Err(err) => err.write_errors(),
87    };
88    TokenStream::from(res)
89}
90
91/// Create a function compatible with [`super::apply_function_to_struct_and_enum_fields`] based on [`ApplyInput`].
92///
93/// A single [`ApplyInput`] can apply to multiple field types.
94/// To account for this a new function must be created to stay compatible with the function signature or [`super::apply_function_to_struct_and_enum_fields`].
95fn prepare_apply_attribute_to_field(
96    input: ApplyInput,
97) -> impl Fn(&mut Field) -> Result<(), DarlingError> {
98    move |field: &mut Field| {
99        let has_skip_attr = super::field_has_attribute(field, "serde_with", "skip_apply");
100        if has_skip_attr {
101            return Ok(());
102        }
103
104        for matcher in input.rules.iter() {
105            if ty_pattern_matches_ty(&matcher.ty, &field.ty) {
106                field.attrs.extend(matcher.attrs.clone());
107            }
108        }
109        Ok(())
110    }
111}
112
113fn ty_pattern_matches_ty(ty_pattern: &Type, ty: &Type) -> bool {
114    match (ty_pattern, ty) {
115        // Groups are invisible groupings which can for example come from macro_rules expansion.
116        // This can lead to a mismatch where the `ty` is "Group { Option<String> }" and the `ty_pattern` is "Option<String>".
117        // To account for this we unwrap the group and compare the inner types.
118        (
119            Type::Group(TypeGroup {
120                elem: ty_pattern, ..
121            }),
122            ty,
123        ) => ty_pattern_matches_ty(ty_pattern, ty),
124        (ty_pattern, Type::Group(TypeGroup { elem: ty, .. })) => {
125            ty_pattern_matches_ty(ty_pattern, ty)
126        }
127
128        // Processing of the other types
129        (
130            Type::Array(TypeArray {
131                elem: ty_pattern,
132                len: len_pattern,
133                ..
134            }),
135            Type::Array(TypeArray { elem: ty, len, .. }),
136        ) => {
137            let ty_match = ty_pattern_matches_ty(ty_pattern, ty);
138            let len_match = len_pattern == len || len_pattern.to_token_stream().to_string() == "_";
139            ty_match && len_match
140        }
141        (Type::BareFn(ty_pattern), Type::BareFn(ty)) => ty_pattern == ty,
142        (Type::ImplTrait(ty_pattern), Type::ImplTrait(ty)) => ty_pattern == ty,
143        (Type::Infer(_), _) => true,
144        (Type::Macro(ty_pattern), Type::Macro(ty)) => ty_pattern == ty,
145        (Type::Never(_), Type::Never(_)) => true,
146        (
147            Type::Paren(TypeParen {
148                elem: ty_pattern, ..
149            }),
150            Type::Paren(TypeParen { elem: ty, .. }),
151        ) => ty_pattern_matches_ty(ty_pattern, ty),
152        (
153            Type::Path(TypePath {
154                qself: qself_pattern,
155                path: path_pattern,
156            }),
157            Type::Path(TypePath { qself, path }),
158        ) => {
159            /// Compare two paths for relaxed equality.
160            ///
161            /// Two paths match if they are equal except for the path arguments.
162            /// Path arguments are generics on types or functions.
163            /// If the pattern has no argument, it can match with everything.
164            /// If the pattern does have an argument, the other side must be equal.
165            fn path_pattern_matches_path(path_pattern: &Path, path: &Path) -> bool {
166                if path_pattern.leading_colon != path.leading_colon
167                    || path_pattern.segments.len() != path.segments.len()
168                {
169                    return false;
170                }
171                // Both parts are equal length
172                std::iter::zip(&path_pattern.segments, &path.segments).all(
173                    |(path_pattern_segment, path_segment)| {
174                        let ident_equal = path_pattern_segment.ident == path_segment.ident;
175                        let args_match =
176                            match (&path_pattern_segment.arguments, &path_segment.arguments) {
177                                (syn::PathArguments::None, _) => true,
178                                (
179                                    syn::PathArguments::AngleBracketed(
180                                        syn::AngleBracketedGenericArguments {
181                                            args: args_pattern,
182                                            ..
183                                        },
184                                    ),
185                                    syn::PathArguments::AngleBracketed(
186                                        syn::AngleBracketedGenericArguments { args, .. },
187                                    ),
188                                ) => {
189                                    args_pattern.len() == args.len()
190                                        && std::iter::zip(args_pattern, args).all(|(a, b)| {
191                                            match (a, b) {
192                                                (
193                                                    syn::GenericArgument::Type(ty_pattern),
194                                                    syn::GenericArgument::Type(ty),
195                                                ) => ty_pattern_matches_ty(ty_pattern, ty),
196                                                (a, b) => a == b,
197                                            }
198                                        })
199                                }
200                                (args_pattern, args) => args_pattern == args,
201                            };
202                        ident_equal && args_match
203                    },
204                )
205            }
206            qself_pattern == qself && path_pattern_matches_path(path_pattern, path)
207        }
208        (
209            Type::Ptr(TypePtr {
210                const_token: const_token_pattern,
211                mutability: mutability_pattern,
212                elem: ty_pattern,
213                ..
214            }),
215            Type::Ptr(TypePtr {
216                const_token,
217                mutability,
218                elem: ty,
219                ..
220            }),
221        ) => {
222            const_token_pattern == const_token
223                && mutability_pattern == mutability
224                && ty_pattern_matches_ty(ty_pattern, ty)
225        }
226        (
227            Type::Reference(TypeReference {
228                lifetime: lifetime_pattern,
229                elem: ty_pattern,
230                ..
231            }),
232            Type::Reference(TypeReference {
233                lifetime, elem: ty, ..
234            }),
235        ) => {
236            (lifetime_pattern.is_none() || lifetime_pattern == lifetime)
237                && ty_pattern_matches_ty(ty_pattern, ty)
238        }
239        (
240            Type::Slice(TypeSlice {
241                elem: ty_pattern, ..
242            }),
243            Type::Slice(TypeSlice { elem: ty, .. }),
244        ) => ty_pattern_matches_ty(ty_pattern, ty),
245        (Type::TraitObject(ty_pattern), Type::TraitObject(ty)) => ty_pattern == ty,
246        (
247            Type::Tuple(TypeTuple {
248                elems: ty_pattern, ..
249            }),
250            Type::Tuple(TypeTuple { elems: ty, .. }),
251        ) => {
252            ty_pattern.len() == ty.len()
253                && std::iter::zip(ty_pattern, ty)
254                    .all(|(ty_pattern, ty)| ty_pattern_matches_ty(ty_pattern, ty))
255        }
256        (Type::Verbatim(_), Type::Verbatim(_)) => false,
257        _ => false,
258    }
259}
260
#[cfg(test)]
mod test {
    use super::*;

    /// Parse both strings as types and run the matcher on them.
    #[track_caller]
    fn matches(ty_pattern: &str, ty: &str) -> bool {
        let ty_pattern = syn::parse_str(ty_pattern).unwrap();
        let ty = syn::parse_str(ty).unwrap();
        ty_pattern_matches_ty(&ty_pattern, &ty)
    }

    #[test]
    fn test_ty_generic() {
        assert!(matches("Option<u8>", "Option<u8>"));
        assert!(matches("Option", "Option<u8>"));
        assert!(!matches("Option<u8>", "Option<String>"));
        assert!(matches("Option<_>", "Option<Vec<u8>>"));
        // A pattern with arguments does not match a path without arguments.
        assert!(!matches("Option<u8>", "Option"));

        assert!(matches("BTreeMap<u8, u8>", "BTreeMap<u8, u8>"));
        assert!(matches("BTreeMap", "BTreeMap<u8, u8>"));
        assert!(!matches("BTreeMap<String, String>", "BTreeMap<u8, u8>"));
        assert!(matches("BTreeMap<_, _>", "BTreeMap<u8, u8>"));
        assert!(matches("BTreeMap<_, u8>", "BTreeMap<u8, u8>"));
        assert!(!matches("BTreeMap<String, _>", "BTreeMap<u8, u8>"));
    }

    #[test]
    fn test_path() {
        assert!(matches("std::vec::Vec", "std::vec::Vec<u8>"));
        // Paths are compared segment by segment; no name resolution takes place.
        assert!(!matches("Vec<u8>", "std::vec::Vec<u8>"));
        assert!(!matches("::std::vec::Vec", "std::vec::Vec"));
    }

    #[test]
    fn test_infer() {
        assert!(matches("_", "u8"));
        assert!(matches("_", "&'a [u8]"));
        // The wildcard only works on the pattern side.
        assert!(!matches("u8", "_"));
    }

    #[test]
    fn test_array() {
        assert!(matches("[u8; 1]", "[u8; 1]"));
        assert!(matches("[_; 1]", "[u8; 1]"));
        assert!(matches("[u8; _]", "[u8; 1]"));
        assert!(matches("[u8; _]", "[u8; N]"));

        assert!(!matches("[u8; 1]", "[u8; 2]"));
        assert!(!matches("[u8; 1]", "[u8; _]"));
        assert!(!matches("[u8; 1]", "[String; 1]"));
    }

    #[test]
    fn test_slice() {
        assert!(matches("[u8]", "[u8]"));
        assert!(matches("[_]", "[u8]"));
        assert!(!matches("[u8]", "[u16]"));
        assert!(!matches("[u8]", "[u8; 1]"));
    }

    #[test]
    fn test_tuple() {
        assert!(matches("()", "()"));
        assert!(matches("(u8, _)", "(u8, String)"));
        assert!(!matches("(u8,)", "(u8, u8)"));
        assert!(!matches("(u8, u8)", "(u8, u16)"));
    }

    #[test]
    fn test_pointer() {
        assert!(matches("*const u8", "*const u8"));
        assert!(matches("*mut _", "*mut String"));
        assert!(!matches("*const u8", "*mut u8"));
    }

    #[test]
    fn test_never_and_bare_fn() {
        assert!(matches("!", "!"));
        assert!(!matches("!", "()"));
        assert!(matches("fn(u8) -> u8", "fn(u8) -> u8"));
        assert!(!matches("fn(u8) -> u8", "fn(u8) -> u16"));
    }

    #[test]
    fn test_reference() {
        assert!(matches("&str", "&str"));
        assert!(matches("&mut str", "&str"));
        assert!(matches("&str", "&mut str"));
        assert!(matches("&str", "&'a str"));
        assert!(matches("&str", "&'static str"));
        assert!(matches("&str", "&'static mut str"));

        assert!(matches("&'a str", "&'a str"));
        assert!(matches("&'a mut str", "&'a str"));

        assert!(!matches("&'b str", "&'a str"));
    }
}