tower_lsp/transport.rs

//! Generic server for multiplexing bidirectional streams through a transport.

#[cfg(feature = "runtime-agnostic")]
use async_codec_lite::{FramedRead, FramedWrite};
#[cfg(feature = "runtime-agnostic")]
use futures::io::{AsyncRead, AsyncWrite};

#[cfg(feature = "runtime-tokio")]
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "runtime-tokio")]
use tokio_util::codec::{FramedRead, FramedWrite};

use futures::channel::mpsc;
use futures::{future, join, stream, FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt};
use tower::Service;
use tracing::error;

use crate::codec::{LanguageServerCodec, ParseError};
use crate::jsonrpc::{Error, Id, Message, Request, Response};
use crate::service::{ClientSocket, RequestStream, ResponseSink};

const DEFAULT_MAX_CONCURRENCY: usize = 4;
const MESSAGE_QUEUE_SIZE: usize = 100;

/// Trait implemented by client loopback sockets.
///
/// This socket handles the server-to-client half of the bidirectional communication stream.
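///
/// [`ClientSocket`] is the implementation used in practice; the sketch below only illustrates
/// the shape of the trait (the `NullSocket` type is made up for this example):
///
/// ```ignore
/// struct NullSocket;
///
/// impl Loopback for NullSocket {
///     // This socket never produces server-to-client requests...
///     type RequestStream = futures::stream::Empty<Request>;
///     // ...and silently discards any responses routed back to it.
///     type ResponseSink = futures::sink::Drain<Response>;
///
///     fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
///         (futures::stream::empty(), futures::sink::drain())
///     }
/// }
/// ```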
pub trait Loopback {
    /// Yields a stream of pending server-to-client requests.
    type RequestStream: Stream<Item = Request>;
    /// Routes client-to-server responses back to the server.
    type ResponseSink: Sink<Response> + Unpin;

    /// Splits this socket into two halves capable of operating independently.
    ///
    /// The two halves returned implement the [`Stream`] and [`Sink`] traits, respectively.
    fn split(self) -> (Self::RequestStream, Self::ResponseSink);
}

impl Loopback for ClientSocket {
    type RequestStream = RequestStream;
    type ResponseSink = ResponseSink;

    #[inline]
    fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
        self.split()
    }
}

/// Server for processing requests and responses on standard I/O or TCP.
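///
/// # Example
///
/// A minimal sketch of serving over standard I/O; the `Backend` type here is assumed to be a
/// `LanguageServer` implementation defined elsewhere:
///
/// ```ignore
/// use tower_lsp::{LspService, Server};
///
/// #[tokio::main]
/// async fn main() {
///     let (stdin, stdout) = (tokio::io::stdin(), tokio::io::stdout());
///
///     // `LspService::new` yields the `tower::Service` along with its `ClientSocket` loopback half.
///     let (service, socket) = LspService::new(|client| Backend { client });
///     Server::new(stdin, stdout, socket).serve(service).await;
/// }
/// ```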
#[derive(Debug)]
pub struct Server<I, O, L = ClientSocket> {
    stdin: I,
    stdout: O,
    loopback: L,
    max_concurrency: usize,
}

impl<I, O, L> Server<I, O, L>
where
    I: AsyncRead + Unpin,
    O: AsyncWrite,
    L: Loopback,
    <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
{
    /// Creates a new `Server` with the given `stdin` and `stdout` handles.
    pub fn new(stdin: I, stdout: O, socket: L) -> Self {
        Server {
            stdin,
            stdout,
            loopback: socket,
            max_concurrency: DEFAULT_MAX_CONCURRENCY,
        }
    }

    /// Sets the server concurrency limit to `max`.
    ///
    /// This setting specifies how many incoming requests may be processed concurrently. Setting
    /// this value to `1` forces all requests to be processed sequentially, thereby implicitly
    /// disabling support for the [`$/cancelRequest`] notification.
    ///
    /// [`$/cancelRequest`]: https://microsoft.github.io/language-server-protocol/specification#cancelRequest
    ///
    /// If not explicitly specified, `max` defaults to 4.
    ///
    /// # Preference over standard `tower` middleware
    ///
    /// The [`ConcurrencyLimit`] and [`Buffer`] middlewares provided by `tower` rely on
    /// [`tokio::spawn`] in common usage, while this library aims to be executor agnostic and to
    /// support exotic targets currently incompatible with `tokio`, such as WASM. As such, `Server`
    /// includes its own concurrency facilities that don't require a global executor to be present.
    ///
    /// [`ConcurrencyLimit`]: https://docs.rs/tower/latest/tower/limit/concurrency/struct.ConcurrencyLimit.html
    /// [`Buffer`]: https://docs.rs/tower/latest/tower/buffer/index.html
    /// [`tokio::spawn`]: https://docs.rs/tokio/latest/tokio/fn.spawn.html
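    ///
    /// # Example
    ///
    /// A sketch, with `stdin`, `stdout`, and `socket` standing in for handles created elsewhere:
    ///
    /// ```ignore
    /// // Force strictly sequential request handling (this also disables `$/cancelRequest`).
    /// let server = Server::new(stdin, stdout, socket).concurrency_level(1);
    /// ```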
    pub fn concurrency_level(mut self, max: usize) -> Self {
        self.max_concurrency = max;
        self
    }

    /// Spawns the service with messages read through `stdin` and responses written to `stdout`.
    pub async fn serve<T>(self, mut service: T)
    where
        T: Service<Request, Response = Option<Response>> + Send + 'static,
        T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        T::Future: Send,
    {
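        // The server is driven by three cooperative futures joined at the end of this function:
        // `read_input` decodes incoming messages from `stdin`, `process_server_tasks` resolves
        // in-flight request futures, and `print_output` encodes outgoing messages to `stdout`.
        // They communicate over the channels created below.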
        let (client_requests, mut client_responses) = self.loopback.split();
        let (client_requests, client_abort) = stream::abortable(client_requests);
        let (mut responses_tx, responses_rx) = mpsc::channel(0);
        let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);

        let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
        let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());

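        // Drive up to `max_concurrency` pending request futures at a time, discard `None`
        // results (notifications produce no response), and forward responses to the output sink.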
        let process_server_tasks = server_tasks_rx
            .buffer_unordered(self.max_concurrency)
            .filter_map(future::ready)
            .map(|res| Ok(Message::Response(res)))
            .forward(responses_tx.clone().sink_map_err(|_| unreachable!()))
            .map(|_| ());

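        // Interleave server responses with server-to-client requests from the loopback socket,
        // writing both to `stdout` through the LSP codec.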
        let print_output = stream::select(responses_rx, client_requests.map(Message::Request))
            .map(Ok)
            .forward(framed_stdout.sink_map_err(|e| error!("failed to encode message: {}", e)))
            .map(|_| ());

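        // Decode messages from `stdin`: requests are dispatched to the service, responses are
        // routed back through the loopback socket, and undecodable input is answered with a
        // JSON-RPC error. Once `stdin` closes, shut the other two futures down.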
        let read_input = async {
            while let Some(msg) = framed_stdin.next().await {
                match msg {
                    Ok(Message::Request(req)) => {
                        if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
                            error!("{}", display_sources(err.into().as_ref()));
                            return;
                        }

                        let fut = service.call(req).unwrap_or_else(|err| {
                            error!("{}", display_sources(err.into().as_ref()));
                            None
                        });

                        server_tasks_tx.send(fut).await.unwrap();
                    }
                    Ok(Message::Response(res)) => {
                        if let Err(err) = client_responses.send(res).await {
                            error!("{}", display_sources(&err));
                            return;
                        }
                    }
                    Err(err) => {
                        error!("failed to decode message: {}", err);
                        let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
                        responses_tx.send(Message::Response(res)).await.unwrap();
                    }
                }
            }

            server_tasks_tx.disconnect();
            responses_tx.disconnect();
            client_abort.abort();
        };

        join!(print_output, read_input, process_server_tasks);
    }
}

fn display_sources(error: &dyn std::error::Error) -> String {
    if let Some(source) = error.source() {
        format!("{}: {}", error, display_sources(source))
    } else {
        error.to_string()
    }
}

#[cfg(feature = "runtime-tokio")]
fn to_jsonrpc_error(err: ParseError) -> Error {
    match err {
        ParseError::Body(err) if err.is_data() => Error::invalid_request(),
        _ => Error::parse_error(),
    }
}

#[cfg(feature = "runtime-agnostic")]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    match err.source().and_then(|e| e.downcast_ref()) {
        Some(ParseError::Body(err)) if err.is_data() => Error::invalid_request(),
        _ => Error::parse_error(),
    }
}

#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let response = serde_json::from_str(RESPONSE).unwrap();
            future::ok(Some(response))
        }
    }

    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            (stream::iter(self.0), sink::drain())
        }
    }

    fn mock_request() -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", REQUEST.len(), REQUEST).into_bytes()
    }

    fn mock_response() -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", RESPONSE.len(), RESPONSE).into_bytes()
    }

    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        (Cursor::new(mock_request()), Vec::new())
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut stdin, mut stdout) = mock_stdio();
        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let socket = MockLoopback(vec![serde_json::from_str(REQUEST).unwrap()]);

        let (mut stdin, mut stdout) = mock_stdio();
        Server::new(&mut stdin, &mut stdout, socket)
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 80);
        let output: Vec<_> = mock_request().into_iter().chain(mock_response()).collect();
        assert_eq!(stdout, output);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let message = format!("Content-Length: {}\r\n\r\n{}", invalid.len(), invalid).into_bytes();
        let (mut stdin, mut stdout) = (Cursor::new(message), Vec::new());

        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        let output = format!("Content-Length: {}\r\n\r\n{}", err.len(), err).into_bytes();
        assert_eq!(stdout, output);
    }
}