1pub use self::client::{Client, ClientSocket, RequestStream, ResponseSink};
4
5pub(crate) use self::pending::Pending;
6pub(crate) use self::state::{ServerState, State};
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use futures::future::{self, BoxFuture, FutureExt};
13use serde_json::Value;
14use tower::Service;
15
16use crate::jsonrpc::{
17 Error, ErrorCode, FromParams, IntoResponse, Method, Request, Response, Router,
18};
19use crate::LanguageServer;
20
21pub(crate) mod layers;
22
23mod client;
24mod pending;
25mod state;
26
/// Error returned by [`LspService`] once the language server has exited.
///
/// Both `poll_ready` and `call` yield this error after the shared server state
/// has transitioned to `State::Exited`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExitedError(());

impl Display for ExitedError {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "language server has exited")
    }
}

impl std::error::Error for ExitedError {}
38
39#[derive(Debug)]
56pub struct LspService<S> {
57 inner: Router<S, ExitedError>,
58 state: Arc<ServerState>,
59}
60
61impl<S: LanguageServer> LspService<S> {
62 pub fn new<F>(init: F) -> (Self, ClientSocket)
65 where
66 F: FnOnce(Client) -> S,
67 {
68 LspService::build(init).finish()
69 }
70
71 pub fn build<F>(init: F) -> LspServiceBuilder<S>
75 where
76 F: FnOnce(Client) -> S,
77 {
78 let state = Arc::new(ServerState::new());
79
80 let (client, socket) = Client::new(state.clone());
81 let inner = Router::new(init(client.clone()));
82 let pending = Arc::new(Pending::new());
83
84 LspServiceBuilder {
85 inner: crate::generated::register_lsp_methods(
86 inner,
87 state.clone(),
88 pending.clone(),
89 client,
90 ),
91 state,
92 pending,
93 socket,
94 }
95 }
96
97 pub fn inner(&self) -> &S {
99 self.inner.inner()
100 }
101}
102
103impl<S: LanguageServer> Service<Request> for LspService<S> {
104 type Response = Option<Response>;
105 type Error = ExitedError;
106 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
107
108 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 match self.state.get() {
110 State::Initializing => Poll::Pending,
111 State::Exited => Poll::Ready(Err(ExitedError(()))),
112 _ => self.inner.poll_ready(cx),
113 }
114 }
115
116 fn call(&mut self, req: Request) -> Self::Future {
117 if self.state.get() == State::Exited {
118 return future::err(ExitedError(())).boxed();
119 }
120
121 let fut = self.inner.call(req);
122
123 Box::pin(async move {
124 let response = fut.await?;
125
126 match response.as_ref().and_then(|res| res.error()) {
127 Some(Error {
128 code: ErrorCode::MethodNotFound,
129 data: Some(Value::String(m)),
130 ..
131 }) if m.starts_with("$/") => Ok(None),
132 _ => Ok(response),
133 }
134 })
135 }
136}
137
138pub struct LspServiceBuilder<S> {
142 inner: Router<S, ExitedError>,
143 state: Arc<ServerState>,
144 pending: Arc<Pending>,
145 socket: ClientSocket,
146}
147
148impl<S: LanguageServer> LspServiceBuilder<S> {
149 pub fn custom_method<P, R, F>(mut self, name: &'static str, callback: F) -> Self
217 where
218 P: FromParams,
219 R: IntoResponse,
220 F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
221 {
222 let layer = layers::Normal::new(self.state.clone(), self.pending.clone());
223 self.inner.method(name, callback, layer);
224 self
225 }
226
227 pub fn finish(self) -> (LspService<S>, ClientSocket) {
230 let LspServiceBuilder {
231 inner,
232 state,
233 socket,
234 ..
235 } = self;
236
237 (LspService { inner, state }, socket)
238 }
239}
240
241impl<S: Debug> Debug for LspServiceBuilder<S> {
242 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
243 f.debug_struct("LspServiceBuilder")
244 .field("inner", &self.inner)
245 .finish_non_exhaustive()
246 }
247}
248
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use lsp_types::*;
    use serde_json::json;
    use tower::ServiceExt;

    use super::*;
    use crate::jsonrpc::Result;

    /// Minimal language server used to drive `LspService` in these tests.
    #[derive(Debug)]
    struct Mock;

    #[async_trait]
    impl LanguageServer for Mock {
        async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
            Ok(InitializeResult::default())
        }

        async fn shutdown(&self) -> Result<()> {
            Ok(())
        }

        async fn code_action_resolve(&self, _: CodeAction) -> Result<CodeAction> {
            // Never completes, so the request stays in flight until cancelled.
            future::pending().await
        }
    }

    impl Mock {
        /// Echoes its integer parameter; registered as the `custom` method.
        async fn custom_request(&self, params: i32) -> Result<i32> {
            Ok(params)
        }
    }

    /// Builds an `initialize` request with empty client capabilities.
    fn initialize_request(id: i64) -> Request {
        Request::build("initialize")
            .params(json!({"capabilities":{}}))
            .id(id)
            .finish()
    }

    /// Sends `initialize` with id 1 and asserts that it succeeds.
    async fn assert_initializes(service: &mut LspService<Mock>) {
        let response = service
            .ready()
            .await
            .unwrap()
            .call(initialize_request(1))
            .await;
        let expected = Response::from_ok(1.into(), json!({"capabilities":{}}));
        assert_eq!(response, Ok(Some(expected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn initializes_only_once() {
        let (mut service, _) = LspService::new(|_| Mock);
        assert_initializes(&mut service).await;

        let again = service
            .ready()
            .await
            .unwrap()
            .call(initialize_request(1))
            .await;
        let expected = Response::from_error(1.into(), Error::invalid_request());
        assert_eq!(again, Ok(Some(expected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refuses_requests_after_shutdown() {
        let (mut service, _) = LspService::new(|_| Mock);
        assert_initializes(&mut service).await;

        let shutdown = Request::build("shutdown").id(1).finish();

        let first = service.ready().await.unwrap().call(shutdown.clone()).await;
        assert_eq!(first, Ok(Some(Response::from_ok(1.into(), json!(null)))));

        let second = service.ready().await.unwrap().call(shutdown).await;
        let expected = Response::from_error(1.into(), Error::invalid_request());
        assert_eq!(second, Ok(Some(expected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn exit_notification() {
        let (mut service, _) = LspService::new(|_| Mock);

        let exit = Request::build("exit").finish();
        let outcome = service.ready().await.unwrap().call(exit.clone()).await;
        assert_eq!(outcome, Ok(None));

        let readiness = future::poll_fn(|cx| service.poll_ready(cx)).await;
        assert_eq!(readiness, Err(ExitedError(())));
        assert_eq!(service.call(exit).await, Err(ExitedError(())));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn cancels_pending_requests() {
        let (mut service, _) = LspService::new(|_| Mock);
        assert_initializes(&mut service).await;

        let resolve = Request::build("codeAction/resolve")
            .params(json!({"title":""}))
            .id(1)
            .finish();
        let cancel = Request::build("$/cancelRequest")
            .params(json!({"id":1i32}))
            .finish();

        let resolve_fut = service.ready().await.unwrap().call(resolve);
        let cancel_fut = service.ready().await.unwrap().call(cancel);
        let (resolve_response, cancel_response) = futures::join!(resolve_fut, cancel_fut);

        let cancelled = Response::from_error(1.into(), Error::request_cancelled());
        assert_eq!(resolve_response, Ok(Some(cancelled)));
        assert_eq!(cancel_response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_custom_requests() {
        let (mut service, _) = LspService::build(|_| Mock)
            .custom_method("custom", Mock::custom_request)
            .finish();
        assert_initializes(&mut service).await;

        let custom = Request::build("custom").params(123i32).id(1).finish();
        let response = service.ready().await.unwrap().call(custom).await;
        assert_eq!(response, Ok(Some(Response::from_ok(1.into(), json!(123i32)))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn get_inner() {
        let (service, _) = LspService::build(|_| Mock).finish();
        let outcome = service
            .inner()
            .initialize(InitializeParams::default())
            .await;
        outcome.unwrap();
    }
}