async_trait/
expand.rs

1use crate::bound::{has_bound, InferredBound, Supertraits};
2use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3use crate::parse::Item;
4use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5use crate::verbatim::VerbatimFn;
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use std::mem;
10use syn::punctuated::Punctuated;
11use syn::visit_mut::{self, VisitMut};
12use syn::{
13    parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14    Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15    ReturnType, Signature, Token, TraitItem, Type, TypeInfer, TypePath, WhereClause,
16};
17
18impl ToTokens for Item {
19    fn to_tokens(&self, tokens: &mut TokenStream) {
20        match self {
21            Item::Trait(item) => item.to_tokens(tokens),
22            Item::Impl(item) => item.to_tokens(tokens),
23        }
24    }
25}
26
27#[derive(Clone, Copy)]
28enum Context<'a> {
29    Trait {
30        generics: &'a Generics,
31        supertraits: &'a Supertraits,
32    },
33    Impl {
34        impl_generics: &'a Generics,
35        associated_type_impl_traits: &'a Set<Ident>,
36    },
37}
38
39impl Context<'_> {
40    fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41        let generics = match self {
42            Context::Trait { generics, .. } => generics,
43            Context::Impl { impl_generics, .. } => impl_generics,
44        };
45        generics.params.iter().filter_map(move |param| {
46            if let GenericParam::Lifetime(param) = param {
47                if used.contains(&param.lifetime) {
48                    return Some(param);
49                }
50            }
51            None
52        })
53    }
54}
55
56pub fn expand(input: &mut Item, is_local: bool) {
57    match input {
58        Item::Trait(input) => {
59            let context = Context::Trait {
60                generics: &input.generics,
61                supertraits: &input.supertraits,
62            };
63            for inner in &mut input.items {
64                if let TraitItem::Fn(method) = inner {
65                    let sig = &mut method.sig;
66                    if sig.asyncness.is_some() {
67                        let block = &mut method.default;
68                        let mut has_self = has_self_in_sig(sig);
69                        method.attrs.push(parse_quote!(#[must_use]));
70                        if let Some(block) = block {
71                            has_self |= has_self_in_block(block);
72                            transform_block(context, sig, block);
73                            method.attrs.push(lint_suppress_with_body());
74                        } else {
75                            method.attrs.push(lint_suppress_without_body());
76                        }
77                        let has_default = method.default.is_some();
78                        transform_sig(context, sig, has_self, has_default, is_local);
79                    }
80                }
81            }
82        }
83        Item::Impl(input) => {
84            let mut associated_type_impl_traits = Set::new();
85            for inner in &input.items {
86                if let ImplItem::Type(assoc) = inner {
87                    if let Type::ImplTrait(_) = assoc.ty {
88                        associated_type_impl_traits.insert(assoc.ident.clone());
89                    }
90                }
91            }
92
93            let context = Context::Impl {
94                impl_generics: &input.generics,
95                associated_type_impl_traits: &associated_type_impl_traits,
96            };
97            for inner in &mut input.items {
98                match inner {
99                    ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100                        let sig = &mut method.sig;
101                        let block = &mut method.block;
102                        let has_self = has_self_in_sig(sig);
103                        transform_block(context, sig, block);
104                        transform_sig(context, sig, has_self, false, is_local);
105                        method.attrs.push(lint_suppress_with_body());
106                    }
107                    ImplItem::Verbatim(tokens) => {
108                        let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109                            Ok(method) if method.sig.asyncness.is_some() => method,
110                            _ => continue,
111                        };
112                        let sig = &mut method.sig;
113                        let has_self = has_self_in_sig(sig);
114                        transform_sig(context, sig, has_self, false, is_local);
115                        method.attrs.push(lint_suppress_with_body());
116                        *tokens = quote!(#method);
117                    }
118                    _ => {}
119                }
120            }
121        }
122    }
123}
124
125fn lint_suppress_with_body() -> Attribute {
126    parse_quote! {
127        #[allow(
128            elided_named_lifetimes,
129            clippy::async_yields_async,
130            clippy::diverging_sub_expression,
131            clippy::let_unit_value,
132            clippy::needless_arbitrary_self_type,
133            clippy::no_effect_underscore_binding,
134            clippy::shadow_same,
135            clippy::type_complexity,
136            clippy::type_repetition_in_bounds,
137            clippy::used_underscore_binding
138        )]
139    }
140}
141
142fn lint_suppress_without_body() -> Attribute {
143    parse_quote! {
144        #[allow(
145            elided_named_lifetimes,
146            clippy::type_complexity,
147            clippy::type_repetition_in_bounds
148        )]
149    }
150}
151
152// Input:
153//     async fn f<T>(&self, x: &T) -> Ret;
154//
155// Output:
156//     fn f<'life0, 'life1, 'async_trait, T>(
157//         &'life0 self,
158//         x: &'life1 T,
159//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
160//     where
161//         'life0: 'async_trait,
162//         'life1: 'async_trait,
163//         T: 'async_trait,
164//         Self: Sync + 'async_trait;
165fn transform_sig(
166    context: Context,
167    sig: &mut Signature,
168    has_self: bool,
169    has_default: bool,
170    is_local: bool,
171) {
172    sig.fn_token.span = sig.asyncness.take().unwrap().span;
173
174    let (ret_arrow, ret) = match &sig.output {
175        ReturnType::Default => (quote!(->), quote!(())),
176        ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)),
177    };
178
179    let mut lifetimes = CollectLifetimes::new();
180    for arg in &mut sig.inputs {
181        match arg {
182            FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183            FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184        }
185    }
186
187    for param in &mut sig.generics.params {
188        match param {
189            GenericParam::Type(param) => {
190                let param_name = &param.ident;
191                let span = match param.colon_token.take() {
192                    Some(colon_token) => colon_token.span,
193                    None => param_name.span(),
194                };
195                if param.attrs.is_empty() {
196                    let bounds = mem::take(&mut param.bounds);
197                    where_clause_or_default(&mut sig.generics.where_clause)
198                        .predicates
199                        .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
200                } else {
201                    param.bounds.push(parse_quote!('async_trait));
202                }
203            }
204            GenericParam::Lifetime(param) => {
205                let param_name = &param.lifetime;
206                let span = match param.colon_token.take() {
207                    Some(colon_token) => colon_token.span,
208                    None => param_name.span(),
209                };
210                if param.attrs.is_empty() {
211                    let bounds = mem::take(&mut param.bounds);
212                    where_clause_or_default(&mut sig.generics.where_clause)
213                        .predicates
214                        .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
215                } else {
216                    param.bounds.push(parse_quote!('async_trait));
217                }
218            }
219            GenericParam::Const(_) => {}
220        }
221    }
222
223    for param in context.lifetimes(&lifetimes.explicit) {
224        let param = &param.lifetime;
225        let span = param.span();
226        where_clause_or_default(&mut sig.generics.where_clause)
227            .predicates
228            .push(parse_quote_spanned!(span=> #param: 'async_trait));
229    }
230
231    if sig.generics.lt_token.is_none() {
232        sig.generics.lt_token = Some(Token![<](sig.ident.span()));
233    }
234    if sig.generics.gt_token.is_none() {
235        sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
236    }
237
238    for elided in lifetimes.elided {
239        sig.generics.params.push(parse_quote!(#elided));
240        where_clause_or_default(&mut sig.generics.where_clause)
241            .predicates
242            .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
243    }
244
245    sig.generics.params.push(parse_quote!('async_trait));
246
247    if has_self {
248        let bounds: &[InferredBound] = if is_local {
249            &[]
250        } else if let Some(receiver) = sig.receiver() {
251            match receiver.ty.as_ref() {
252                // self: &Self
253                Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
254                // self: Arc<Self>
255                Type::Path(ty)
256                    if {
257                        let segment = ty.path.segments.last().unwrap();
258                        segment.ident == "Arc"
259                            && match &segment.arguments {
260                                PathArguments::AngleBracketed(arguments) => {
261                                    arguments.args.len() == 1
262                                        && match &arguments.args[0] {
263                                            GenericArgument::Type(Type::Path(arg)) => {
264                                                arg.path.is_ident("Self")
265                                            }
266                                            _ => false,
267                                        }
268                                }
269                                _ => false,
270                            }
271                    } =>
272                {
273                    &[InferredBound::Sync, InferredBound::Send]
274                }
275                _ => &[InferredBound::Send],
276            }
277        } else {
278            &[InferredBound::Send]
279        };
280
281        let bounds = bounds.iter().filter(|bound| match context {
282            Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
283            Context::Impl { .. } => false,
284        });
285
286        where_clause_or_default(&mut sig.generics.where_clause)
287            .predicates
288            .push(parse_quote! {
289                Self: #(#bounds +)* 'async_trait
290            });
291    }
292
293    for (i, arg) in sig.inputs.iter_mut().enumerate() {
294        match arg {
295            FnArg::Receiver(receiver) => {
296                if receiver.reference.is_none() {
297                    receiver.mutability = None;
298                }
299            }
300            FnArg::Typed(arg) => {
301                if match *arg.ty {
302                    Type::Reference(_) => false,
303                    _ => true,
304                } {
305                    if let Pat::Ident(pat) = &mut *arg.pat {
306                        pat.by_ref = None;
307                        pat.mutability = None;
308                    } else {
309                        let positional = positional_arg(i, &arg.pat);
310                        let m = mut_pat(&mut arg.pat);
311                        arg.pat = parse_quote!(#m #positional);
312                    }
313                }
314                AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
315            }
316        }
317    }
318
319    let bounds = if is_local {
320        quote!('async_trait)
321    } else {
322        quote!(::core::marker::Send + 'async_trait)
323    };
324    sig.output = parse_quote! {
325        #ret_arrow ::core::pin::Pin<Box<
326            dyn ::core::future::Future<Output = #ret> + #bounds
327        >>
328    };
329}
330
331// Input:
332//     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
333//         self + x + a + b
334//     }
335//
336// Output:
337//     Box::pin(async move {
338//         let ___ret: Ret = {
339//             let __self = self;
340//             let x = x;
341//             let (a, b) = __arg1;
342//
343//             __self + x + a + b
344//         };
345//
346//         ___ret
347//     })
348fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
349    let mut replace_self = false;
350    let decls = sig
351        .inputs
352        .iter()
353        .enumerate()
354        .map(|(i, arg)| match arg {
355            FnArg::Receiver(Receiver {
356                self_token,
357                mutability,
358                ..
359            }) => {
360                replace_self = true;
361                let ident = Ident::new("__self", self_token.span);
362                quote!(let #mutability #ident = #self_token;)
363            }
364            FnArg::Typed(arg) => {
365                // If there is a #[cfg(...)] attribute that selectively enables
366                // the parameter, forward it to the variable.
367                //
368                // This is currently not applied to the `self` parameter.
369                let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
370
371                if let Type::Reference(_) = *arg.ty {
372                    quote!()
373                } else if let Pat::Ident(PatIdent {
374                    ident, mutability, ..
375                }) = &*arg.pat
376                {
377                    quote! {
378                        #(#attrs)*
379                        let #mutability #ident = #ident;
380                    }
381                } else {
382                    let pat = &arg.pat;
383                    let ident = positional_arg(i, pat);
384                    if let Pat::Wild(_) = **pat {
385                        quote! {
386                            #(#attrs)*
387                            let #ident = #ident;
388                        }
389                    } else {
390                        quote! {
391                            #(#attrs)*
392                            let #pat = {
393                                let #ident = #ident;
394                                #ident
395                            };
396                        }
397                    }
398                }
399            }
400        })
401        .collect::<Vec<_>>();
402
403    if replace_self {
404        ReplaceSelf.visit_block_mut(block);
405    }
406
407    let stmts = &block.stmts;
408    let let_ret = match &mut sig.output {
409        ReturnType::Default => quote_spanned! {block.brace_token.span=>
410            #(#decls)*
411            let () = { #(#stmts)* };
412        },
413        ReturnType::Type(_, ret) => {
414            if contains_associated_type_impl_trait(context, ret) {
415                if decls.is_empty() {
416                    quote!(#(#stmts)*)
417                } else {
418                    quote!(#(#decls)* { #(#stmts)* })
419                }
420            } else {
421                let mut ret = ret.clone();
422                replace_impl_trait_with_infer(&mut ret);
423                quote! {
424                    if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
425                        #[allow(unreachable_code)]
426                        return __ret;
427                    }
428                    #(#decls)*
429                    let __ret: #ret = { #(#stmts)* };
430                    #[allow(unreachable_code)]
431                    __ret
432                }
433            }
434        }
435    };
436    let box_pin = quote_spanned!(block.brace_token.span=>
437        Box::pin(async move { #let_ret })
438    );
439    block.stmts = parse_quote!(#box_pin);
440}
441
442fn positional_arg(i: usize, pat: &Pat) -> Ident {
443    let span = syn::spanned::Spanned::span(pat).resolved_at(Span::mixed_site());
444    format_ident!("__arg{}", i, span = span)
445}
446
447fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
448    struct AssociatedTypeImplTraits<'a> {
449        set: &'a Set<Ident>,
450        contains: bool,
451    }
452
453    impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
454        fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
455            if ty.qself.is_none()
456                && ty.path.segments.len() == 2
457                && ty.path.segments[0].ident == "Self"
458                && self.set.contains(&ty.path.segments[1].ident)
459            {
460                self.contains = true;
461            }
462            visit_mut::visit_type_path_mut(self, ty);
463        }
464    }
465
466    match context {
467        Context::Trait { .. } => false,
468        Context::Impl {
469            associated_type_impl_traits,
470            ..
471        } => {
472            let mut visit = AssociatedTypeImplTraits {
473                set: associated_type_impl_traits,
474                contains: false,
475            };
476            visit.visit_type_mut(ret);
477            visit.contains
478        }
479    }
480}
481
482fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
483    clause.get_or_insert_with(|| WhereClause {
484        where_token: Default::default(),
485        predicates: Punctuated::new(),
486    })
487}
488
489fn replace_impl_trait_with_infer(ty: &mut Type) {
490    struct ReplaceImplTraitWithInfer;
491
492    impl VisitMut for ReplaceImplTraitWithInfer {
493        fn visit_type_mut(&mut self, ty: &mut Type) {
494            if let Type::ImplTrait(impl_trait) = ty {
495                *ty = Type::Infer(TypeInfer {
496                    underscore_token: Token![_](impl_trait.impl_token.span),
497                });
498            }
499            visit_mut::visit_type_mut(self, ty);
500        }
501    }
502
503    ReplaceImplTraitWithInfer.visit_type_mut(ty);
504}