tower_lsp/
service.rs

1//! Service abstraction for language servers.
2
3pub use self::client::{Client, ClientSocket, RequestStream, ResponseSink};
4
5pub(crate) use self::pending::Pending;
6pub(crate) use self::state::{ServerState, State};
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use futures::future::{self, BoxFuture, FutureExt};
13use serde_json::Value;
14use tower::Service;
15
16use crate::jsonrpc::{
17    Error, ErrorCode, FromParams, IntoResponse, Method, Request, Response, Router,
18};
19use crate::LanguageServer;
20
21pub(crate) mod layers;
22
23mod client;
24mod pending;
25mod state;
26
/// Returned when a request is sent to the language server after it has already exited.
///
/// The unit field is private, so values of this type can only be created inside this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExitedError(());

impl Display for ExitedError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = "language server has exited";
        f.write_str(message)
    }
}

// No source error and no custom behavior: the `Display` text above is the whole story.
impl std::error::Error for ExitedError {}
38
39/// Service abstraction for the Language Server Protocol.
40///
41/// This service takes an incoming JSON-RPC message as input and produces an outgoing message as
42/// output. If the incoming message is a server notification or a client response, then the
43/// corresponding response will be `None`.
44///
45/// This implements [`tower::Service`] in order to remain independent from the underlying transport
46/// and to facilitate further abstraction with middleware.
47///
48/// Pending requests can be canceled by issuing a [`$/cancelRequest`] notification.
49///
50/// [`$/cancelRequest`]: https://microsoft.github.io/language-server-protocol/specification#cancelRequest
51///
52/// The service shuts down and stops serving requests after the [`exit`] notification is received.
53///
54/// [`exit`]: https://microsoft.github.io/language-server-protocol/specification#exit
55#[derive(Debug)]
56pub struct LspService<S> {
57    inner: Router<S, ExitedError>,
58    state: Arc<ServerState>,
59}
60
61impl<S: LanguageServer> LspService<S> {
62    /// Creates a new `LspService` with the given server backend, also returning a channel for
63    /// server-to-client communication.
64    pub fn new<F>(init: F) -> (Self, ClientSocket)
65    where
66        F: FnOnce(Client) -> S,
67    {
68        LspService::build(init).finish()
69    }
70
71    /// Starts building a new `LspService`.
72    ///
73    /// Returns an `LspServiceBuilder`, which allows adding custom JSON-RPC methods to the server.
74    pub fn build<F>(init: F) -> LspServiceBuilder<S>
75    where
76        F: FnOnce(Client) -> S,
77    {
78        let state = Arc::new(ServerState::new());
79
80        let (client, socket) = Client::new(state.clone());
81        let inner = Router::new(init(client.clone()));
82        let pending = Arc::new(Pending::new());
83
84        LspServiceBuilder {
85            inner: crate::generated::register_lsp_methods(
86                inner,
87                state.clone(),
88                pending.clone(),
89                client,
90            ),
91            state,
92            pending,
93            socket,
94        }
95    }
96
97    /// Returns a reference to the inner server.
98    pub fn inner(&self) -> &S {
99        self.inner.inner()
100    }
101}
102
103impl<S: LanguageServer> Service<Request> for LspService<S> {
104    type Response = Option<Response>;
105    type Error = ExitedError;
106    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
107
108    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        match self.state.get() {
110            State::Initializing => Poll::Pending,
111            State::Exited => Poll::Ready(Err(ExitedError(()))),
112            _ => self.inner.poll_ready(cx),
113        }
114    }
115
116    fn call(&mut self, req: Request) -> Self::Future {
117        if self.state.get() == State::Exited {
118            return future::err(ExitedError(())).boxed();
119        }
120
121        let fut = self.inner.call(req);
122
123        Box::pin(async move {
124            let response = fut.await?;
125
126            match response.as_ref().and_then(|res| res.error()) {
127                Some(Error {
128                    code: ErrorCode::MethodNotFound,
129                    data: Some(Value::String(m)),
130                    ..
131                }) if m.starts_with("$/") => Ok(None),
132                _ => Ok(response),
133            }
134        })
135    }
136}
137
138/// A builder to customize the properties of an `LspService`.
139///
140/// To construct an `LspServiceBuilder`, refer to [`LspService::build`].
141pub struct LspServiceBuilder<S> {
142    inner: Router<S, ExitedError>,
143    state: Arc<ServerState>,
144    pending: Arc<Pending>,
145    socket: ClientSocket,
146}
147
148impl<S: LanguageServer> LspServiceBuilder<S> {
149    /// Defines a custom JSON-RPC request or notification with the given method `name` and handler.
150    ///
151    /// # Handler varieties
152    ///
153    /// Fundamentally, any inherent `async fn(&self)` method defined directly on the language
154    /// server backend could be considered a valid method handler.
155    ///
156    /// Handlers may optionally include a single `params` argument. This argument may be of any
157    /// type that implements [`Serialize`](serde::Serialize).
158    ///
159    /// Handlers which return `()` are treated as **notifications**, while those which return
160    /// [`jsonrpc::Result<T>`](crate::jsonrpc::Result) are treated as **requests**.
161    ///
162    /// Similar to the `params` argument, the `T` in the `Result<T>` return values may be of any
163    /// type which implements [`DeserializeOwned`](serde::de::DeserializeOwned). Additionally, this
164    /// type _must_ be convertible into a [`serde_json::Value`] using [`serde_json::to_value`]. If
165    /// this latter constraint is not met, the client will receive a JSON-RPC error response with
166    /// code `-32603` (Internal Error) instead of the expected response.
167    ///
168    /// # Examples
169    ///
170    /// ```rust
171    /// use serde_json::{json, Value};
172    /// use tower_lsp::jsonrpc::Result;
173    /// use tower_lsp::lsp_types::*;
174    /// use tower_lsp::{LanguageServer, LspService};
175    ///
176    /// struct Mock;
177    ///
178    /// // Implementation of `LanguageServer` omitted...
179    /// # #[tower_lsp::async_trait]
180    /// # impl LanguageServer for Mock {
181    /// #     async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
182    /// #         Ok(InitializeResult::default())
183    /// #     }
184    /// #
185    /// #     async fn shutdown(&self) -> Result<()> {
186    /// #         Ok(())
187    /// #     }
188    /// # }
189    ///
190    /// impl Mock {
191    ///     async fn request(&self) -> Result<i32> {
192    ///         Ok(123)
193    ///     }
194    ///
195    ///     async fn request_params(&self, params: Vec<String>) -> Result<Value> {
196    ///         Ok(json!({"num_elems":params.len()}))
197    ///     }
198    ///
199    ///     async fn notification(&self) {
200    ///         // ...
201    ///     }
202    ///
203    ///     async fn notification_params(&self, params: Value) {
204    ///         // ...
205    /// #       let _ = params;
206    ///     }
207    /// }
208    ///
209    /// let (service, socket) = LspService::build(|_| Mock)
210    ///     .custom_method("custom/request", Mock::request)
211    ///     .custom_method("custom/requestParams", Mock::request_params)
212    ///     .custom_method("custom/notification", Mock::notification)
213    ///     .custom_method("custom/notificationParams", Mock::notification_params)
214    ///     .finish();
215    /// ```
216    pub fn custom_method<P, R, F>(mut self, name: &'static str, callback: F) -> Self
217    where
218        P: FromParams,
219        R: IntoResponse,
220        F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
221    {
222        let layer = layers::Normal::new(self.state.clone(), self.pending.clone());
223        self.inner.method(name, callback, layer);
224        self
225    }
226
227    /// Constructs the `LspService` and returns it, along with a channel for server-to-client
228    /// communication.
229    pub fn finish(self) -> (LspService<S>, ClientSocket) {
230        let LspServiceBuilder {
231            inner,
232            state,
233            socket,
234            ..
235        } = self;
236
237        (LspService { inner, state }, socket)
238    }
239}
240
241impl<S: Debug> Debug for LspServiceBuilder<S> {
242    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
243        f.debug_struct("LspServiceBuilder")
244            .field("inner", &self.inner)
245            .finish_non_exhaustive()
246    }
247}
248
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use lsp_types::*;
    use serde_json::json;
    use tower::ServiceExt;

    use super::*;
    use crate::jsonrpc::Result;

    /// Minimal backend: answers `initialize`/`shutdown` and never finishes `codeAction/resolve`.
    #[derive(Debug)]
    struct Mock;

    #[async_trait]
    impl LanguageServer for Mock {
        async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
            Ok(InitializeResult::default())
        }

        async fn shutdown(&self) -> Result<()> {
            Ok(())
        }

        // Deliberately never resolves, so the request stays in flight until it is canceled.
        async fn code_action_resolve(&self, _: CodeAction) -> Result<CodeAction> {
            future::pending().await
        }
    }

    impl Mock {
        /// Echoes its integer params back as the result.
        async fn custom_request(&self, params: i32) -> Result<i32> {
            Ok(params)
        }
    }

    /// Builds an `initialize` request with empty client capabilities.
    fn initialize_request(id: i64) -> Request {
        Request::build("initialize")
            .params(json!({"capabilities":{}}))
            .id(id)
            .finish()
    }

    /// The success response `Mock` produces for an `initialize` request with the given `id`.
    fn initialized_response(id: i64) -> Response {
        Response::from_ok(id.into(), json!({"capabilities":{}}))
    }

    /// Waits until `service` is ready, sends `req`, and awaits the outcome.
    async fn send(
        service: &mut LspService<Mock>,
        req: Request,
    ) -> std::result::Result<Option<Response>, ExitedError> {
        service.ready().await.unwrap().call(req).await
    }

    #[tokio::test(flavor = "current_thread")]
    async fn initializes_only_once() {
        let (mut service, _) = LspService::new(|_| Mock);
        let init = initialize_request(1);

        let first = send(&mut service, init.clone()).await;
        assert_eq!(first, Ok(Some(initialized_response(1))));

        // A second `initialize` is rejected as an invalid request.
        let second = send(&mut service, init).await;
        let rejected = Response::from_error(1.into(), Error::invalid_request());
        assert_eq!(second, Ok(Some(rejected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refuses_requests_after_shutdown() {
        let (mut service, _) = LspService::new(|_| Mock);

        let init_outcome = send(&mut service, initialize_request(1)).await;
        assert_eq!(init_outcome, Ok(Some(initialized_response(1))));

        let shutdown = Request::build("shutdown").id(1).finish();

        let first = send(&mut service, shutdown.clone()).await;
        let acknowledged = Response::from_ok(1.into(), json!(null));
        assert_eq!(first, Ok(Some(acknowledged)));

        // Once shut down, a repeated `shutdown` is rejected as an invalid request.
        let second = send(&mut service, shutdown).await;
        let rejected = Response::from_error(1.into(), Error::invalid_request());
        assert_eq!(second, Ok(Some(rejected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn exit_notification() {
        let (mut service, _) = LspService::new(|_| Mock);
        let exit = Request::build("exit").finish();

        // `exit` is a notification, so it yields no response...
        assert_eq!(send(&mut service, exit.clone()).await, Ok(None));

        // ...and afterwards both `poll_ready` and `call` report that the server has exited.
        let readiness = future::poll_fn(|cx| service.poll_ready(cx)).await;
        assert_eq!(readiness, Err(ExitedError(())));
        assert_eq!(service.call(exit).await, Err(ExitedError(())));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn cancels_pending_requests() {
        let (mut service, _) = LspService::new(|_| Mock);

        let init_outcome = send(&mut service, initialize_request(1)).await;
        assert_eq!(init_outcome, Ok(Some(initialized_response(1))));

        let resolve = Request::build("codeAction/resolve")
            .params(json!({"title":""}))
            .id(1)
            .finish();
        let cancel = Request::build("$/cancelRequest")
            .params(json!({"id":1i32}))
            .finish();

        // Both calls must be issued before either future is awaited: the `codeAction/resolve`
        // handler never resolves on its own, so it only completes through the cancellation.
        let resolve_fut = service.ready().await.unwrap().call(resolve);
        let cancel_fut = service.ready().await.unwrap().call(cancel);
        let (resolve_outcome, cancel_outcome) = futures::join!(resolve_fut, cancel_fut);

        let cancelled = Response::from_error(1.into(), Error::request_cancelled());
        assert_eq!(resolve_outcome, Ok(Some(cancelled)));
        assert_eq!(cancel_outcome, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_custom_requests() {
        let (mut service, _) = LspService::build(|_| Mock)
            .custom_method("custom", Mock::custom_request)
            .finish();

        let init_outcome = send(&mut service, initialize_request(1)).await;
        assert_eq!(init_outcome, Ok(Some(initialized_response(1))));

        let custom = Request::build("custom").params(123i32).id(1).finish();
        let echoed = Response::from_ok(1.into(), json!(123i32));
        assert_eq!(send(&mut service, custom).await, Ok(Some(echoed)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn get_inner() {
        let (service, _) = LspService::build(|_| Mock).finish();

        let backend = service.inner();
        backend
            .initialize(InitializeParams::default())
            .await
            .unwrap();
    }
}