async_stream_impl/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree};
3use quote::quote;
4use syn::parse::{Parse, ParseStream, Parser, Result};
5use syn::visit_mut::VisitMut;
6
7struct Scrub<'a> {
8    /// Whether the stream is a try stream.
9    is_try: bool,
10    /// The unit expression, `()`.
11    unit: Box<syn::Expr>,
12    has_yielded: bool,
13    crate_path: &'a TokenStream2,
14}
15
16fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)> {
17    let mut input = TokenStream2::from(input).into_iter();
18    let crate_path = match input.next().unwrap() {
19        TokenTree::Group(group) => group.stream(),
20        _ => panic!(),
21    };
22    let stmts = syn::Block::parse_within.parse2(replace_for_await(input))?;
23    Ok((crate_path, stmts))
24}
25
26impl<'a> Scrub<'a> {
27    fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self {
28        Self {
29            is_try,
30            unit: syn::parse_quote!(()),
31            has_yielded: false,
32            crate_path,
33        }
34    }
35}
36
37struct Partial<T>(T, TokenStream2);
38
39impl<T: Parse> Parse for Partial<T> {
40    fn parse(input: ParseStream) -> Result<Self> {
41        Ok(Partial(input.parse()?, input.parse()?))
42    }
43}
44
45fn visit_token_stream_impl(
46    visitor: &mut Scrub<'_>,
47    tokens: TokenStream2,
48    modified: &mut bool,
49    out: &mut TokenStream2,
50) {
51    use quote::ToTokens;
52    use quote::TokenStreamExt;
53
54    let mut tokens = tokens.into_iter().peekable();
55    while let Some(tt) = tokens.next() {
56        match tt {
57            TokenTree::Ident(i) if i == "yield" => {
58                let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect();
59                match syn::parse2(stream) {
60                    Ok(Partial(yield_expr, rest)) => {
61                        let mut expr = syn::Expr::Yield(yield_expr);
62                        visitor.visit_expr_mut(&mut expr);
63                        expr.to_tokens(out);
64                        *modified = true;
65                        tokens = rest.into_iter().peekable();
66                    }
67                    Err(e) => {
68                        out.append_all(e.to_compile_error().into_iter());
69                        *modified = true;
70                        return;
71                    }
72                }
73            }
74            TokenTree::Ident(i) if i == "stream" || i == "try_stream" => {
75                out.append(TokenTree::Ident(i));
76                match tokens.peek() {
77                    Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
78                        out.extend(tokens.next()); // !
79                        if let Some(TokenTree::Group(_)) = tokens.peek() {
80                            out.extend(tokens.next()); // { .. } or [ .. ] or ( .. )
81                        }
82                    }
83                    _ => {}
84                }
85            }
86            TokenTree::Group(group) => {
87                let mut content = group.stream();
88                *modified |= visitor.visit_token_stream(&mut content);
89                let mut new = Group::new(group.delimiter(), content);
90                new.set_span(group.span());
91                out.append(new);
92            }
93            other => out.append(other),
94        }
95    }
96}
97
98impl Scrub<'_> {
99    fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool {
100        let (mut out, mut modified) = (TokenStream2::new(), false);
101        visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out);
102
103        if modified {
104            *tokens = out;
105        }
106
107        modified
108    }
109}
110
111impl VisitMut for Scrub<'_> {
112    fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
113        match i {
114            syn::Expr::Yield(yield_expr) => {
115                self.has_yielded = true;
116
117                syn::visit_mut::visit_expr_yield_mut(self, yield_expr);
118
119                let value_expr = yield_expr.expr.as_ref().unwrap_or(&self.unit);
120
121                // let ident = &self.yielder;
122
123                let yield_expr = if self.is_try {
124                    quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await }
125                } else {
126                    quote! { __yield_tx.send(#value_expr).await }
127                };
128                *i = syn::parse_quote! {
129                    {
130                        #[allow(unreachable_code)]
131                        if false {
132                            break '__async_stream_private_check_scope (loop {});
133                        }
134                        #yield_expr
135                    }
136                };
137            }
138            syn::Expr::Try(try_expr) => {
139                syn::visit_mut::visit_expr_try_mut(self, try_expr);
140                // let ident = &self.yielder;
141                let e = &try_expr.expr;
142
143                *i = syn::parse_quote! {
144                    match #e {
145                        ::core::result::Result::Ok(v) => v,
146                        ::core::result::Result::Err(e) => {
147                            __yield_tx.send(::core::result::Result::Err(e.into())).await;
148                            return;
149                        }
150                    }
151                };
152            }
153            syn::Expr::Closure(_) | syn::Expr::Async(_) => {
154                // Don't transform inner closures or async blocks.
155            }
156            syn::Expr::ForLoop(expr) => {
157                syn::visit_mut::visit_expr_for_loop_mut(self, expr);
158                // TODO: Should we allow other attributes?
159                if expr.attrs.len() != 1 || !expr.attrs[0].meta.path().is_ident(AWAIT_ATTR_NAME) {
160                    return;
161                }
162                let syn::ExprForLoop {
163                    attrs,
164                    label,
165                    pat,
166                    expr,
167                    body,
168                    ..
169                } = expr;
170
171                attrs.pop().unwrap();
172
173                let crate_path = self.crate_path;
174                *i = syn::parse_quote! {{
175                    let mut __pinned = #expr;
176                    let mut __pinned = unsafe {
177                        ::core::pin::Pin::new_unchecked(&mut __pinned)
178                    };
179                    #label
180                    loop {
181                        let #pat = match #crate_path::__private::next(&mut __pinned).await {
182                            ::core::option::Option::Some(e) => e,
183                            ::core::option::Option::None => break,
184                        };
185                        #body
186                    }
187                }}
188            }
189            _ => syn::visit_mut::visit_expr_mut(self, i),
190        }
191    }
192
193    fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
194        let mac_ident = mac.path.segments.last().map(|p| &p.ident);
195        if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream") {
196            return;
197        }
198
199        self.visit_token_stream(&mut mac.tokens);
200    }
201
202    fn visit_item_mut(&mut self, i: &mut syn::Item) {
203        // Recurse into macros but otherwise don't transform inner items.
204        if let syn::Item::Macro(i) = i {
205            self.visit_macro_mut(&mut i.mac);
206        }
207    }
208}
209
210/// The first token tree in the stream must be a group containing the path to the `async-stream`
211/// crate.
212#[proc_macro]
213#[doc(hidden)]
214pub fn stream_inner(input: TokenStream) -> TokenStream {
215    let (crate_path, mut stmts) = match parse_input(input) {
216        Ok(x) => x,
217        Err(e) => return e.to_compile_error().into(),
218    };
219
220    let mut scrub = Scrub::new(false, &crate_path);
221
222    for stmt in &mut stmts {
223        scrub.visit_stmt_mut(stmt);
224    }
225
226    let dummy_yield = if scrub.has_yielded {
227        None
228    } else {
229        Some(quote!(if false {
230            __yield_tx.send(()).await;
231        }))
232    };
233
234    quote!({
235        let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() };
236        #crate_path::__private::AsyncStream::new(__yield_rx, async move {
237            '__async_stream_private_check_scope: {
238                #dummy_yield
239                #(#stmts)*
240            }
241        })
242    })
243    .into()
244}
245
246/// The first token tree in the stream must be a group containing the path to the `async-stream`
247/// crate.
248#[proc_macro]
249#[doc(hidden)]
250pub fn try_stream_inner(input: TokenStream) -> TokenStream {
251    let (crate_path, mut stmts) = match parse_input(input) {
252        Ok(x) => x,
253        Err(e) => return e.to_compile_error().into(),
254    };
255
256    let mut scrub = Scrub::new(true, &crate_path);
257
258    for stmt in &mut stmts {
259        scrub.visit_stmt_mut(stmt);
260    }
261
262    let dummy_yield = if scrub.has_yielded {
263        None
264    } else {
265        Some(quote!(if false {
266            __yield_tx.send(()).await;
267        }))
268    };
269
270    quote!({
271        let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() };
272        #crate_path::__private::AsyncStream::new(__yield_rx, async move {
273            '__async_stream_private_check_scope: {
274                #dummy_yield
275                #(#stmts)*
276            }
277        })
278    })
279    .into()
280}
281
// syn 2.0 won't parse `#[await] for x in xs {}` because `await` is a keyword. The marker
// attribute that `replace_for_await` inserts for `for await` loops is therefore named
// `await_` instead.
const AWAIT_ATTR_NAME: &str = "await_";
285
286/// Replace `for await` with `#[await] for`, which will be later transformed into a `next` loop.
287fn replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2 {
288    let mut input = input.into_iter().peekable();
289    let mut tokens = Vec::new();
290
291    while let Some(token) = input.next() {
292        match token {
293            TokenTree::Ident(ident) => {
294                match input.peek() {
295                    Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => {
296                        let next_span = next.span();
297                        let next = syn::Ident::new(AWAIT_ATTR_NAME, next_span);
298                        tokens.extend(quote!(#[#next]));
299                        let _ = input.next();
300                    }
301                    _ => {}
302                }
303                tokens.push(ident.into());
304            }
305            TokenTree::Group(group) => {
306                let stream = replace_for_await(group.stream());
307                let mut new_group = Group::new(group.delimiter(), stream);
308                new_group.set_span(group.span());
309                tokens.push(new_group.into());
310            }
311            _ => tokens.push(token),
312        }
313    }
314
315    tokens.into_iter().collect()
316}