1use darling::{ast::NestedMeta, Error as DarlingError, FromMeta};
2use proc_macro::TokenStream;
3use quote::ToTokens as _;
4use syn::{
5 parse::{Parse, ParseStream},
6 punctuated::Punctuated,
7 Attribute, Error, Field, Path, Token, Type, TypeArray, TypeGroup, TypeParen, TypePath, TypePtr,
8 TypeReference, TypeSlice, TypeTuple,
9};
10
11struct AddAttributesRule {
16 ty: Type,
18 attrs: Vec<Attribute>,
22}
23
24impl Parse for AddAttributesRule {
25 fn parse(input: ParseStream<'_>) -> Result<Self, Error> {
26 let ty: Type = input.parse()?;
27 input.parse::<Token![=>]>()?;
28 let attr = Attribute::parse_outer(input)?;
29 Ok(AddAttributesRule { ty, attrs: attr })
30 }
31}
32
33struct ApplyInput {
38 metas: Vec<NestedMeta>,
39 rules: Punctuated<AddAttributesRule, Token![,]>,
40}
41
42impl Parse for ApplyInput {
43 fn parse(input: ParseStream<'_>) -> Result<Self, Error> {
44 let mut metas: Vec<NestedMeta> = Vec::new();
45
46 while input.peek2(Token![=]) && !input.peek2(Token![=>]) {
47 let value = NestedMeta::parse(input)?;
48 metas.push(value);
49 if !input.peek(Token![,]) {
50 break;
51 }
52 input.parse::<Token![,]>()?;
53 }
54
55 let rules: Punctuated<AddAttributesRule, Token![,]> =
56 input.parse_terminated(AddAttributesRule::parse, Token![,])?;
57 Ok(Self { metas, rules })
58 }
59}
60
61pub fn apply(args: TokenStream, input: TokenStream) -> TokenStream {
62 let args = syn::parse_macro_input!(args as ApplyInput);
63
64 #[derive(FromMeta)]
65 struct SerdeContainerOptions {
66 #[darling(rename = "crate")]
67 alt_crate_path: Option<Path>,
68 }
69
70 let container_options = match SerdeContainerOptions::from_list(&args.metas) {
71 Ok(v) => v,
72 Err(e) => {
73 return TokenStream::from(e.write_errors());
74 }
75 };
76 let serde_with_crate_path = container_options
77 .alt_crate_path
78 .unwrap_or_else(|| syn::parse_quote!(::serde_with));
79
80 let res = match super::apply_function_to_struct_and_enum_fields_darling(
81 input,
82 &serde_with_crate_path,
83 &prepare_apply_attribute_to_field(args),
84 ) {
85 Ok(res) => res,
86 Err(err) => err.write_errors(),
87 };
88 TokenStream::from(res)
89}
90
91fn prepare_apply_attribute_to_field(
96 input: ApplyInput,
97) -> impl Fn(&mut Field) -> Result<(), DarlingError> {
98 move |field: &mut Field| {
99 let has_skip_attr = super::field_has_attribute(field, "serde_with", "skip_apply");
100 if has_skip_attr {
101 return Ok(());
102 }
103
104 for matcher in input.rules.iter() {
105 if ty_pattern_matches_ty(&matcher.ty, &field.ty) {
106 field.attrs.extend(matcher.attrs.clone());
107 }
108 }
109 Ok(())
110 }
111}
112
113fn ty_pattern_matches_ty(ty_pattern: &Type, ty: &Type) -> bool {
114 match (ty_pattern, ty) {
115 (
119 Type::Group(TypeGroup {
120 elem: ty_pattern, ..
121 }),
122 ty,
123 ) => ty_pattern_matches_ty(ty_pattern, ty),
124 (ty_pattern, Type::Group(TypeGroup { elem: ty, .. })) => {
125 ty_pattern_matches_ty(ty_pattern, ty)
126 }
127
128 (
130 Type::Array(TypeArray {
131 elem: ty_pattern,
132 len: len_pattern,
133 ..
134 }),
135 Type::Array(TypeArray { elem: ty, len, .. }),
136 ) => {
137 let ty_match = ty_pattern_matches_ty(ty_pattern, ty);
138 let len_match = len_pattern == len || len_pattern.to_token_stream().to_string() == "_";
139 ty_match && len_match
140 }
141 (Type::BareFn(ty_pattern), Type::BareFn(ty)) => ty_pattern == ty,
142 (Type::ImplTrait(ty_pattern), Type::ImplTrait(ty)) => ty_pattern == ty,
143 (Type::Infer(_), _) => true,
144 (Type::Macro(ty_pattern), Type::Macro(ty)) => ty_pattern == ty,
145 (Type::Never(_), Type::Never(_)) => true,
146 (
147 Type::Paren(TypeParen {
148 elem: ty_pattern, ..
149 }),
150 Type::Paren(TypeParen { elem: ty, .. }),
151 ) => ty_pattern_matches_ty(ty_pattern, ty),
152 (
153 Type::Path(TypePath {
154 qself: qself_pattern,
155 path: path_pattern,
156 }),
157 Type::Path(TypePath { qself, path }),
158 ) => {
159 fn path_pattern_matches_path(path_pattern: &Path, path: &Path) -> bool {
166 if path_pattern.leading_colon != path.leading_colon
167 || path_pattern.segments.len() != path.segments.len()
168 {
169 return false;
170 }
171 std::iter::zip(&path_pattern.segments, &path.segments).all(
173 |(path_pattern_segment, path_segment)| {
174 let ident_equal = path_pattern_segment.ident == path_segment.ident;
175 let args_match =
176 match (&path_pattern_segment.arguments, &path_segment.arguments) {
177 (syn::PathArguments::None, _) => true,
178 (
179 syn::PathArguments::AngleBracketed(
180 syn::AngleBracketedGenericArguments {
181 args: args_pattern,
182 ..
183 },
184 ),
185 syn::PathArguments::AngleBracketed(
186 syn::AngleBracketedGenericArguments { args, .. },
187 ),
188 ) => {
189 args_pattern.len() == args.len()
190 && std::iter::zip(args_pattern, args).all(|(a, b)| {
191 match (a, b) {
192 (
193 syn::GenericArgument::Type(ty_pattern),
194 syn::GenericArgument::Type(ty),
195 ) => ty_pattern_matches_ty(ty_pattern, ty),
196 (a, b) => a == b,
197 }
198 })
199 }
200 (args_pattern, args) => args_pattern == args,
201 };
202 ident_equal && args_match
203 },
204 )
205 }
206 qself_pattern == qself && path_pattern_matches_path(path_pattern, path)
207 }
208 (
209 Type::Ptr(TypePtr {
210 const_token: const_token_pattern,
211 mutability: mutability_pattern,
212 elem: ty_pattern,
213 ..
214 }),
215 Type::Ptr(TypePtr {
216 const_token,
217 mutability,
218 elem: ty,
219 ..
220 }),
221 ) => {
222 const_token_pattern == const_token
223 && mutability_pattern == mutability
224 && ty_pattern_matches_ty(ty_pattern, ty)
225 }
226 (
227 Type::Reference(TypeReference {
228 lifetime: lifetime_pattern,
229 elem: ty_pattern,
230 ..
231 }),
232 Type::Reference(TypeReference {
233 lifetime, elem: ty, ..
234 }),
235 ) => {
236 (lifetime_pattern.is_none() || lifetime_pattern == lifetime)
237 && ty_pattern_matches_ty(ty_pattern, ty)
238 }
239 (
240 Type::Slice(TypeSlice {
241 elem: ty_pattern, ..
242 }),
243 Type::Slice(TypeSlice { elem: ty, .. }),
244 ) => ty_pattern_matches_ty(ty_pattern, ty),
245 (Type::TraitObject(ty_pattern), Type::TraitObject(ty)) => ty_pattern == ty,
246 (
247 Type::Tuple(TypeTuple {
248 elems: ty_pattern, ..
249 }),
250 Type::Tuple(TypeTuple { elems: ty, .. }),
251 ) => {
252 ty_pattern.len() == ty.len()
253 && std::iter::zip(ty_pattern, ty)
254 .all(|(ty_pattern, ty)| ty_pattern_matches_ty(ty_pattern, ty))
255 }
256 (Type::Verbatim(_), Type::Verbatim(_)) => false,
257 _ => false,
258 }
259}
260
#[cfg(test)]
mod test {
    use super::*;

    /// Parses both strings as types and reports whether the pattern matches the type.
    #[track_caller]
    fn matches(ty_pattern: &str, ty: &str) -> bool {
        let pattern: syn::Type = syn::parse_str(ty_pattern).unwrap();
        let candidate: syn::Type = syn::parse_str(ty).unwrap();
        ty_pattern_matches_ty(&pattern, &candidate)
    }

    /// Fails (reporting the caller's line) unless `ty_pattern` matches `ty`.
    #[track_caller]
    fn assert_match(ty_pattern: &str, ty: &str) {
        assert!(matches(ty_pattern, ty));
    }

    /// Fails (reporting the caller's line) if `ty_pattern` matches `ty`.
    #[track_caller]
    fn assert_no_match(ty_pattern: &str, ty: &str) {
        assert!(!matches(ty_pattern, ty));
    }

    #[test]
    fn test_ty_generic() {
        assert_match("Option<u8>", "Option<u8>");
        assert_match("Option", "Option<u8>");
        assert_no_match("Option<u8>", "Option<String>");

        assert_match("BTreeMap<u8, u8>", "BTreeMap<u8, u8>");
        assert_match("BTreeMap", "BTreeMap<u8, u8>");
        assert_no_match("BTreeMap<String, String>", "BTreeMap<u8, u8>");
        assert_match("BTreeMap<_, _>", "BTreeMap<u8, u8>");
        assert_match("BTreeMap<_, u8>", "BTreeMap<u8, u8>");
        assert_no_match("BTreeMap<String, _>", "BTreeMap<u8, u8>");
    }

    #[test]
    fn test_array() {
        assert_match("[u8; 1]", "[u8; 1]");
        assert_match("[_; 1]", "[u8; 1]");
        assert_match("[u8; _]", "[u8; 1]");
        assert_match("[u8; _]", "[u8; N]");

        assert_no_match("[u8; 1]", "[u8; 2]");
        assert_no_match("[u8; 1]", "[u8; _]");
        assert_no_match("[u8; 1]", "[String; 1]");
    }

    #[test]
    fn test_reference() {
        assert_match("&str", "&str");
        assert_match("&mut str", "&str");
        assert_match("&str", "&mut str");
        assert_match("&str", "&'a str");
        assert_match("&str", "&'static str");
        assert_match("&str", "&'static mut str");

        assert_match("&'a str", "&'a str");
        assert_match("&'a mut str", "&'a str");

        assert_no_match("&'b str", "&'a str");
    }
}