tower_lsp_macros/
lib.rs

1//! Internal procedural macros for [`tower-lsp`](https://docs.rs/tower-lsp).
2//!
3//! This crate should not be used directly.
4
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::{parse_macro_input, FnArg, ItemTrait, LitStr, ReturnType, TraitItem};
10
11/// Macro for generating LSP server implementation from [`lsp-types`](https://docs.rs/lsp-types).
12///
13/// This procedural macro annotates the `tower_lsp::LanguageServer` trait and generates a
14/// corresponding `register_lsp_methods()` function which registers all the methods on that trait
15/// as RPC handlers.
16#[proc_macro_attribute]
17pub fn rpc(attr: TokenStream, item: TokenStream) -> TokenStream {
18    // Attribute will be parsed later in `parse_method_calls()`.
19    if !attr.is_empty() {
20        return item;
21    }
22
23    let lang_server_trait = parse_macro_input!(item as ItemTrait);
24    let method_calls = parse_method_calls(&lang_server_trait);
25    let req_types_and_router_fn = gen_server_router(&lang_server_trait.ident, &method_calls);
26
27    let tokens = quote! {
28        #lang_server_trait
29        #req_types_and_router_fn
30    };
31
32    tokens.into()
33}
34
35struct MethodCall<'a> {
36    rpc_name: String,
37    handler_name: &'a syn::Ident,
38    params: Option<&'a syn::Type>,
39    result: Option<&'a syn::Type>,
40}
41
42fn parse_method_calls(lang_server_trait: &ItemTrait) -> Vec<MethodCall> {
43    let mut calls = Vec::new();
44
45    for item in &lang_server_trait.items {
46        let method = match item {
47            TraitItem::Fn(m) => m,
48            _ => continue,
49        };
50
51        let attr = method
52            .attrs
53            .iter()
54            .find(|attr| attr.meta.path().is_ident("rpc"))
55            .expect("expected `#[rpc(name = \"foo\")]` attribute");
56
57        let mut rpc_name = String::new();
58        attr.parse_nested_meta(|meta| {
59            if meta.path.is_ident("name") {
60                let s: LitStr = meta.value().and_then(|v| v.parse())?;
61                rpc_name = s.value();
62                Ok(())
63            } else {
64                Err(meta.error("expected `name` identifier in `#[rpc]`"))
65            }
66        })
67        .unwrap();
68
69        let params = method.sig.inputs.iter().nth(1).and_then(|arg| match arg {
70            FnArg::Typed(pat) => Some(&*pat.ty),
71            _ => None,
72        });
73
74        let result = match &method.sig.output {
75            ReturnType::Default => None,
76            ReturnType::Type(_, ty) => Some(&**ty),
77        };
78
79        calls.push(MethodCall {
80            rpc_name,
81            handler_name: &method.sig.ident,
82            params,
83            result,
84        });
85    }
86
87    calls
88}
89
90fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_macro2::TokenStream {
91    let route_registrations: proc_macro2::TokenStream = methods
92        .iter()
93        .map(|method| {
94            let rpc_name = &method.rpc_name;
95            let handler = &method.handler_name;
96
97            let layer = match &rpc_name[..] {
98                "initialize" => quote! { layers::Initialize::new(state.clone(), pending.clone()) },
99                "shutdown" => quote! { layers::Shutdown::new(state.clone(), pending.clone()) },
100                _ => quote! { layers::Normal::new(state.clone(), pending.clone()) },
101            };
102
103            // NOTE: In a perfect world, we could simply loop over each `MethodCall` and emit
104            // `router.method(#rpc_name, S::#handler);` for each. While such an approach
105            // works for inherent async functions and methods, it breaks with `async-trait` methods
106            // due to this unfortunate `rustc` bug:
107            //
108            // https://github.com/rust-lang/rust/issues/64552
109            //
110            // As a workaround, we wrap each `async-trait` method in a regular `async fn` before
111            // passing it to `.method`, as documented in this GitHub issue:
112            //
113            // https://github.com/dtolnay/async-trait/issues/167
114            match (method.params, method.result) {
115                (Some(params), Some(result)) => quote! {
116                    async fn #handler<S: #trait_name>(server: &S, params: #params) -> #result {
117                        server.#handler(params).await
118                    }
119                    router.method(#rpc_name, #handler, #layer);
120                },
121                (None, Some(result)) => quote! {
122                    async fn #handler<S: #trait_name>(server: &S) -> #result {
123                        server.#handler().await
124                    }
125                    router.method(#rpc_name, #handler, #layer);
126                },
127                (Some(params), None) => quote! {
128                    async fn #handler<S: #trait_name>(server: &S, params: #params) {
129                        server.#handler(params).await
130                    }
131                    router.method(#rpc_name, #handler, #layer);
132                },
133                (None, None) => quote! {
134                    async fn #handler<S: #trait_name>(server: &S) {
135                        server.#handler().await
136                    }
137                    router.method(#rpc_name, #handler, #layer);
138                },
139            }
140        })
141        .collect();
142
143    quote! {
144        mod generated {
145            use std::sync::Arc;
146            use std::future::{Future, Ready};
147
148            use lsp_types::*;
149            use lsp_types::notification::*;
150            use lsp_types::request::*;
151            use serde_json::Value;
152
153            use super::#trait_name;
154            use crate::jsonrpc::{Result, Router};
155            use crate::service::{layers, Client, Pending, ServerState, State, ExitedError};
156
157            fn cancel_request(params: CancelParams, p: &Pending) -> Ready<()> {
158                p.cancel(&params.id.into());
159                std::future::ready(())
160            }
161
162            pub(crate) fn register_lsp_methods<S>(
163                mut router: Router<S, ExitedError>,
164                state: Arc<ServerState>,
165                pending: Arc<Pending>,
166                client: Client,
167            ) -> Router<S, ExitedError>
168            where
169                S: #trait_name,
170            {
171                #route_registrations
172
173                let p = pending.clone();
174                router.method(
175                    "$/cancelRequest",
176                    move |_: &S, params| cancel_request(params, &p),
177                    tower::layer::util::Identity::new(),
178                );
179                router.method(
180                    "exit",
181                    |_: &S| std::future::ready(()),
182                    layers::Exit::new(state.clone(), pending, client.clone()),
183                );
184
185                router
186            }
187        }
188    }
189}