// tower_lsp/service/layers.rs

//! Assorted middleware that implements LSP server semantics.
2
3use std::marker::PhantomData;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures::future::{self, BoxFuture, FutureExt};
8use tower::{Layer, Service};
9use tracing::{info, warn};
10
11use super::ExitedError;
12use crate::jsonrpc::{not_initialized_error, Error, Id, Request, Response};
13
14use super::client::Client;
15use super::pending::Pending;
16use super::state::{ServerState, State};
17
18/// Middleware which implements `initialize` request semantics.
19///
20/// # Specification
21///
22/// https://microsoft.github.io/language-server-protocol/specification#initialize
23pub struct Initialize {
24    state: Arc<ServerState>,
25    pending: Arc<Pending>,
26}
27
28impl Initialize {
29    pub fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
30        Initialize { state, pending }
31    }
32}
33
34impl<S> Layer<S> for Initialize {
35    type Service = InitializeService<S>;
36
37    fn layer(&self, inner: S) -> Self::Service {
38        InitializeService {
39            inner: Cancellable::new(inner, self.pending.clone()),
40            state: self.state.clone(),
41        }
42    }
43}
44
45/// Service created from [`Initialize`] layer.
46pub struct InitializeService<S> {
47    inner: Cancellable<S>,
48    state: Arc<ServerState>,
49}
50
51impl<S> Service<Request> for InitializeService<S>
52where
53    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
54    S::Future: Send + 'static,
55{
56    type Response = S::Response;
57    type Error = S::Error;
58    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
59
60    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61        self.inner.poll_ready(cx)
62    }
63
64    fn call(&mut self, req: Request) -> Self::Future {
65        if self.state.get() == State::Uninitialized {
66            let state = self.state.clone();
67            let fut = self.inner.call(req);
68
69            Box::pin(async move {
70                let response = fut.await?;
71
72                match &response {
73                    Some(res) if res.is_ok() => state.set(State::Initialized),
74                    _ => state.set(State::Uninitialized),
75                }
76
77                Ok(response)
78            })
79        } else {
80            warn!("received duplicate `initialize` request, ignoring");
81            let (_, id, _) = req.into_parts();
82            future::ok(id.map(|id| Response::from_error(id, Error::invalid_request()))).boxed()
83        }
84    }
85}
86
87/// Middleware which implements `shutdown` request semantics.
88///
89/// # Specification
90///
91/// https://microsoft.github.io/language-server-protocol/specification#shutdown
92pub struct Shutdown {
93    state: Arc<ServerState>,
94    pending: Arc<Pending>,
95}
96
97impl Shutdown {
98    pub fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
99        Shutdown { state, pending }
100    }
101}
102
103impl<S> Layer<S> for Shutdown {
104    type Service = ShutdownService<S>;
105
106    fn layer(&self, inner: S) -> Self::Service {
107        ShutdownService {
108            inner: Cancellable::new(inner, self.pending.clone()),
109            state: self.state.clone(),
110        }
111    }
112}
113
114/// Service created from [`Shutdown`] layer.
115pub struct ShutdownService<S> {
116    inner: Cancellable<S>,
117    state: Arc<ServerState>,
118}
119
120impl<S> Service<Request> for ShutdownService<S>
121where
122    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
123    S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
124{
125    type Response = S::Response;
126    type Error = S::Error;
127    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
128
129    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130        self.inner.poll_ready(cx)
131    }
132
133    fn call(&mut self, req: Request) -> Self::Future {
134        match self.state.get() {
135            State::Initialized => {
136                info!("shutdown request received, shutting down");
137                self.state.set(State::ShutDown);
138                self.inner.call(req)
139            }
140            cur_state => {
141                let (_, id, _) = req.into_parts();
142                future::ok(not_initialized_response(id, cur_state)).boxed()
143            }
144        }
145    }
146}
147
148/// Middleware which implements `exit` notification semantics.
149///
150/// # Specification
151///
152/// https://microsoft.github.io/language-server-protocol/specification#exit
153pub struct Exit {
154    state: Arc<ServerState>,
155    pending: Arc<Pending>,
156    client: Client,
157}
158
159impl Exit {
160    pub fn new(state: Arc<ServerState>, pending: Arc<Pending>, client: Client) -> Self {
161        Exit {
162            state,
163            pending,
164            client,
165        }
166    }
167}
168
169impl<S> Layer<S> for Exit {
170    type Service = ExitService<S>;
171
172    fn layer(&self, _: S) -> Self::Service {
173        ExitService {
174            state: self.state.clone(),
175            pending: self.pending.clone(),
176            client: self.client.clone(),
177            _marker: PhantomData,
178        }
179    }
180}
181
182/// Service created from [`Exit`] layer.
183pub struct ExitService<S> {
184    state: Arc<ServerState>,
185    pending: Arc<Pending>,
186    client: Client,
187    _marker: PhantomData<S>,
188}
189
190impl<S> Service<Request> for ExitService<S> {
191    type Response = Option<Response>;
192    type Error = ExitedError;
193    type Future = future::Ready<Result<Self::Response, Self::Error>>;
194
195    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196        if self.state.get() == State::Exited {
197            Poll::Ready(Err(ExitedError(())))
198        } else {
199            Poll::Ready(Ok(()))
200        }
201    }
202
203    fn call(&mut self, _: Request) -> Self::Future {
204        info!("exit notification received, stopping");
205        self.state.set(State::Exited);
206        self.pending.cancel_all();
207        self.client.close();
208        future::ok(None)
209    }
210}
211
212/// Middleware which implements LSP semantics for all other kinds of requests.
213pub struct Normal {
214    state: Arc<ServerState>,
215    pending: Arc<Pending>,
216}
217
218impl Normal {
219    pub fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
220        Normal { state, pending }
221    }
222}
223
224impl<S> Layer<S> for Normal {
225    type Service = NormalService<S>;
226
227    fn layer(&self, inner: S) -> Self::Service {
228        NormalService {
229            inner: Cancellable::new(inner, self.pending.clone()),
230            state: self.state.clone(),
231        }
232    }
233}
234
235/// Service created from [`Normal`] layer.
236pub struct NormalService<S> {
237    inner: Cancellable<S>,
238    state: Arc<ServerState>,
239}
240
241impl<S> Service<Request> for NormalService<S>
242where
243    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
244    S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
245{
246    type Response = S::Response;
247    type Error = S::Error;
248    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
249
250    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
251        self.inner.poll_ready(cx)
252    }
253
254    fn call(&mut self, req: Request) -> Self::Future {
255        match self.state.get() {
256            State::Initialized => self.inner.call(req),
257            cur_state => {
258                let (_, id, _) = req.into_parts();
259                future::ok(not_initialized_response(id, cur_state)).boxed()
260            }
261        }
262    }
263}
264
265/// Wraps an inner service `S` and implements `$/cancelRequest` semantics for all requests.
266///
267/// # Specification
268///
269/// https://microsoft.github.io/language-server-protocol/specification#cancelRequest
270struct Cancellable<S> {
271    inner: S,
272    pending: Arc<Pending>,
273}
274
275impl<S> Cancellable<S> {
276    fn new(inner: S, pending: Arc<Pending>) -> Self {
277        Cancellable { inner, pending }
278    }
279}
280
281impl<S> Service<Request> for Cancellable<S>
282where
283    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
284    S::Future: Send + 'static,
285{
286    type Response = S::Response;
287    type Error = S::Error;
288    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
289
290    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
291        self.inner.poll_ready(cx)
292    }
293
294    fn call(&mut self, req: Request) -> Self::Future {
295        match req.id().cloned() {
296            Some(id) => self.pending.execute(id, self.inner.call(req)).boxed(),
297            None => self.inner.call(req).boxed(),
298        }
299    }
300}
301
302fn not_initialized_response(id: Option<Id>, server_state: State) -> Option<Response> {
303    let id = id?;
304    let error = match server_state {
305        State::Uninitialized | State::Initializing => not_initialized_error(),
306        _ => Error::invalid_request(),
307    };
308
309    Some(Response::from_error(id, error))
310}
311
// TODO: Add `tower-test` middleware tests for each middleware in this module.