1use crate::bound::{has_bound, InferredBound, Supertraits};
2use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3use crate::parse::Item;
4use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5use crate::verbatim::VerbatimFn;
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use std::mem;
10use syn::punctuated::Punctuated;
11use syn::visit_mut::{self, VisitMut};
12use syn::{
13 parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14 Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15 ReturnType, Signature, Token, TraitItem, Type, TypeInfer, TypePath, WhereClause,
16};
17
18impl ToTokens for Item {
19 fn to_tokens(&self, tokens: &mut TokenStream) {
20 match self {
21 Item::Trait(item) => item.to_tokens(tokens),
22 Item::Impl(item) => item.to_tokens(tokens),
23 }
24 }
25}
26
27#[derive(Clone, Copy)]
28enum Context<'a> {
29 Trait {
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 associated_type_impl_traits: &'a Set<Ident>,
36 },
37}
38
39impl Context<'_> {
40 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41 let generics = match self {
42 Context::Trait { generics, .. } => generics,
43 Context::Impl { impl_generics, .. } => impl_generics,
44 };
45 generics.params.iter().filter_map(move |param| {
46 if let GenericParam::Lifetime(param) = param {
47 if used.contains(¶m.lifetime) {
48 return Some(param);
49 }
50 }
51 None
52 })
53 }
54}
55
56pub fn expand(input: &mut Item, is_local: bool) {
57 match input {
58 Item::Trait(input) => {
59 let context = Context::Trait {
60 generics: &input.generics,
61 supertraits: &input.supertraits,
62 };
63 for inner in &mut input.items {
64 if let TraitItem::Fn(method) = inner {
65 let sig = &mut method.sig;
66 if sig.asyncness.is_some() {
67 let block = &mut method.default;
68 let mut has_self = has_self_in_sig(sig);
69 method.attrs.push(parse_quote!(#[must_use]));
70 if let Some(block) = block {
71 has_self |= has_self_in_block(block);
72 transform_block(context, sig, block);
73 method.attrs.push(lint_suppress_with_body());
74 } else {
75 method.attrs.push(lint_suppress_without_body());
76 }
77 let has_default = method.default.is_some();
78 transform_sig(context, sig, has_self, has_default, is_local);
79 }
80 }
81 }
82 }
83 Item::Impl(input) => {
84 let mut associated_type_impl_traits = Set::new();
85 for inner in &input.items {
86 if let ImplItem::Type(assoc) = inner {
87 if let Type::ImplTrait(_) = assoc.ty {
88 associated_type_impl_traits.insert(assoc.ident.clone());
89 }
90 }
91 }
92
93 let context = Context::Impl {
94 impl_generics: &input.generics,
95 associated_type_impl_traits: &associated_type_impl_traits,
96 };
97 for inner in &mut input.items {
98 match inner {
99 ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100 let sig = &mut method.sig;
101 let block = &mut method.block;
102 let has_self = has_self_in_sig(sig);
103 transform_block(context, sig, block);
104 transform_sig(context, sig, has_self, false, is_local);
105 method.attrs.push(lint_suppress_with_body());
106 }
107 ImplItem::Verbatim(tokens) => {
108 let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109 Ok(method) if method.sig.asyncness.is_some() => method,
110 _ => continue,
111 };
112 let sig = &mut method.sig;
113 let has_self = has_self_in_sig(sig);
114 transform_sig(context, sig, has_self, false, is_local);
115 method.attrs.push(lint_suppress_with_body());
116 *tokens = quote!(#method);
117 }
118 _ => {}
119 }
120 }
121 }
122 }
123}
124
125fn lint_suppress_with_body() -> Attribute {
126 parse_quote! {
127 #[allow(
128 elided_named_lifetimes,
129 clippy::async_yields_async,
130 clippy::diverging_sub_expression,
131 clippy::let_unit_value,
132 clippy::needless_arbitrary_self_type,
133 clippy::no_effect_underscore_binding,
134 clippy::shadow_same,
135 clippy::type_complexity,
136 clippy::type_repetition_in_bounds,
137 clippy::used_underscore_binding
138 )]
139 }
140}
141
142fn lint_suppress_without_body() -> Attribute {
143 parse_quote! {
144 #[allow(
145 elided_named_lifetimes,
146 clippy::type_complexity,
147 clippy::type_repetition_in_bounds
148 )]
149 }
150}
151
152fn transform_sig(
166 context: Context,
167 sig: &mut Signature,
168 has_self: bool,
169 has_default: bool,
170 is_local: bool,
171) {
172 sig.fn_token.span = sig.asyncness.take().unwrap().span;
173
174 let (ret_arrow, ret) = match &sig.output {
175 ReturnType::Default => (quote!(->), quote!(())),
176 ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)),
177 };
178
179 let mut lifetimes = CollectLifetimes::new();
180 for arg in &mut sig.inputs {
181 match arg {
182 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184 }
185 }
186
187 for param in &mut sig.generics.params {
188 match param {
189 GenericParam::Type(param) => {
190 let param_name = ¶m.ident;
191 let span = match param.colon_token.take() {
192 Some(colon_token) => colon_token.span,
193 None => param_name.span(),
194 };
195 if param.attrs.is_empty() {
196 let bounds = mem::take(&mut param.bounds);
197 where_clause_or_default(&mut sig.generics.where_clause)
198 .predicates
199 .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
200 } else {
201 param.bounds.push(parse_quote!('async_trait));
202 }
203 }
204 GenericParam::Lifetime(param) => {
205 let param_name = ¶m.lifetime;
206 let span = match param.colon_token.take() {
207 Some(colon_token) => colon_token.span,
208 None => param_name.span(),
209 };
210 if param.attrs.is_empty() {
211 let bounds = mem::take(&mut param.bounds);
212 where_clause_or_default(&mut sig.generics.where_clause)
213 .predicates
214 .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
215 } else {
216 param.bounds.push(parse_quote!('async_trait));
217 }
218 }
219 GenericParam::Const(_) => {}
220 }
221 }
222
223 for param in context.lifetimes(&lifetimes.explicit) {
224 let param = ¶m.lifetime;
225 let span = param.span();
226 where_clause_or_default(&mut sig.generics.where_clause)
227 .predicates
228 .push(parse_quote_spanned!(span=> #param: 'async_trait));
229 }
230
231 if sig.generics.lt_token.is_none() {
232 sig.generics.lt_token = Some(Token));
233 }
234 if sig.generics.gt_token.is_none() {
235 sig.generics.gt_token = Some(Token));
236 }
237
238 for elided in lifetimes.elided {
239 sig.generics.params.push(parse_quote!(#elided));
240 where_clause_or_default(&mut sig.generics.where_clause)
241 .predicates
242 .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
243 }
244
245 sig.generics.params.push(parse_quote!('async_trait));
246
247 if has_self {
248 let bounds: &[InferredBound] = if is_local {
249 &[]
250 } else if let Some(receiver) = sig.receiver() {
251 match receiver.ty.as_ref() {
252 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
254 Type::Path(ty)
256 if {
257 let segment = ty.path.segments.last().unwrap();
258 segment.ident == "Arc"
259 && match &segment.arguments {
260 PathArguments::AngleBracketed(arguments) => {
261 arguments.args.len() == 1
262 && match &arguments.args[0] {
263 GenericArgument::Type(Type::Path(arg)) => {
264 arg.path.is_ident("Self")
265 }
266 _ => false,
267 }
268 }
269 _ => false,
270 }
271 } =>
272 {
273 &[InferredBound::Sync, InferredBound::Send]
274 }
275 _ => &[InferredBound::Send],
276 }
277 } else {
278 &[InferredBound::Send]
279 };
280
281 let bounds = bounds.iter().filter(|bound| match context {
282 Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
283 Context::Impl { .. } => false,
284 });
285
286 where_clause_or_default(&mut sig.generics.where_clause)
287 .predicates
288 .push(parse_quote! {
289 Self: #(#bounds +)* 'async_trait
290 });
291 }
292
293 for (i, arg) in sig.inputs.iter_mut().enumerate() {
294 match arg {
295 FnArg::Receiver(receiver) => {
296 if receiver.reference.is_none() {
297 receiver.mutability = None;
298 }
299 }
300 FnArg::Typed(arg) => {
301 if match *arg.ty {
302 Type::Reference(_) => false,
303 _ => true,
304 } {
305 if let Pat::Ident(pat) = &mut *arg.pat {
306 pat.by_ref = None;
307 pat.mutability = None;
308 } else {
309 let positional = positional_arg(i, &arg.pat);
310 let m = mut_pat(&mut arg.pat);
311 arg.pat = parse_quote!(#m #positional);
312 }
313 }
314 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
315 }
316 }
317 }
318
319 let bounds = if is_local {
320 quote!('async_trait)
321 } else {
322 quote!(::core::marker::Send + 'async_trait)
323 };
324 sig.output = parse_quote! {
325 #ret_arrow ::core::pin::Pin<Box<
326 dyn ::core::future::Future<Output = #ret> + #bounds
327 >>
328 };
329}
330
331fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
349 let mut replace_self = false;
350 let decls = sig
351 .inputs
352 .iter()
353 .enumerate()
354 .map(|(i, arg)| match arg {
355 FnArg::Receiver(Receiver {
356 self_token,
357 mutability,
358 ..
359 }) => {
360 replace_self = true;
361 let ident = Ident::new("__self", self_token.span);
362 quote!(let #mutability #ident = #self_token;)
363 }
364 FnArg::Typed(arg) => {
365 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
370
371 if let Type::Reference(_) = *arg.ty {
372 quote!()
373 } else if let Pat::Ident(PatIdent {
374 ident, mutability, ..
375 }) = &*arg.pat
376 {
377 quote! {
378 #(#attrs)*
379 let #mutability #ident = #ident;
380 }
381 } else {
382 let pat = &arg.pat;
383 let ident = positional_arg(i, pat);
384 if let Pat::Wild(_) = **pat {
385 quote! {
386 #(#attrs)*
387 let #ident = #ident;
388 }
389 } else {
390 quote! {
391 #(#attrs)*
392 let #pat = {
393 let #ident = #ident;
394 #ident
395 };
396 }
397 }
398 }
399 }
400 })
401 .collect::<Vec<_>>();
402
403 if replace_self {
404 ReplaceSelf.visit_block_mut(block);
405 }
406
407 let stmts = &block.stmts;
408 let let_ret = match &mut sig.output {
409 ReturnType::Default => quote_spanned! {block.brace_token.span=>
410 #(#decls)*
411 let () = { #(#stmts)* };
412 },
413 ReturnType::Type(_, ret) => {
414 if contains_associated_type_impl_trait(context, ret) {
415 if decls.is_empty() {
416 quote!(#(#stmts)*)
417 } else {
418 quote!(#(#decls)* { #(#stmts)* })
419 }
420 } else {
421 let mut ret = ret.clone();
422 replace_impl_trait_with_infer(&mut ret);
423 quote! {
424 if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
425 #[allow(unreachable_code)]
426 return __ret;
427 }
428 #(#decls)*
429 let __ret: #ret = { #(#stmts)* };
430 #[allow(unreachable_code)]
431 __ret
432 }
433 }
434 }
435 };
436 let box_pin = quote_spanned!(block.brace_token.span=>
437 Box::pin(async move { #let_ret })
438 );
439 block.stmts = parse_quote!(#box_pin);
440}
441
442fn positional_arg(i: usize, pat: &Pat) -> Ident {
443 let span = syn::spanned::Spanned::span(pat).resolved_at(Span::mixed_site());
444 format_ident!("__arg{}", i, span = span)
445}
446
447fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
448 struct AssociatedTypeImplTraits<'a> {
449 set: &'a Set<Ident>,
450 contains: bool,
451 }
452
453 impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
454 fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
455 if ty.qself.is_none()
456 && ty.path.segments.len() == 2
457 && ty.path.segments[0].ident == "Self"
458 && self.set.contains(&ty.path.segments[1].ident)
459 {
460 self.contains = true;
461 }
462 visit_mut::visit_type_path_mut(self, ty);
463 }
464 }
465
466 match context {
467 Context::Trait { .. } => false,
468 Context::Impl {
469 associated_type_impl_traits,
470 ..
471 } => {
472 let mut visit = AssociatedTypeImplTraits {
473 set: associated_type_impl_traits,
474 contains: false,
475 };
476 visit.visit_type_mut(ret);
477 visit.contains
478 }
479 }
480}
481
482fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
483 clause.get_or_insert_with(|| WhereClause {
484 where_token: Default::default(),
485 predicates: Punctuated::new(),
486 })
487}
488
489fn replace_impl_trait_with_infer(ty: &mut Type) {
490 struct ReplaceImplTraitWithInfer;
491
492 impl VisitMut for ReplaceImplTraitWithInfer {
493 fn visit_type_mut(&mut self, ty: &mut Type) {
494 if let Type::ImplTrait(impl_trait) = ty {
495 *ty = Type::Infer(TypeInfer {
496 underscore_token: Token,
497 });
498 }
499 visit_mut::visit_type_mut(self, ty);
500 }
501 }
502
503 ReplaceImplTraitWithInfer.visit_type_mut(ty);
504}