1#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{future, join, stream, FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt};
15use tower::Service;
16use tracing::error;
17
18use crate::codec::{LanguageServerCodec, ParseError};
19use crate::jsonrpc::{Error, Id, Message, Request, Response};
20use crate::service::{ClientSocket, RequestStream, ResponseSink};
21
/// Buffer size of the channel that carries request futures from the input loop to the
/// concurrent executor in `Server::serve`.
const MESSAGE_QUEUE_SIZE: usize = 100;

/// Number of requests `Server::serve` processes concurrently unless overridden through
/// `Server::concurrency_level`.
const DEFAULT_MAX_CONCURRENCY: usize = 4;
24
25pub trait Loopback {
29 type RequestStream: Stream<Item = Request>;
31 type ResponseSink: Sink<Response> + Unpin;
33
34 fn split(self) -> (Self::RequestStream, Self::ResponseSink);
38}
39
40impl Loopback for ClientSocket {
41 type RequestStream = RequestStream;
42 type ResponseSink = ResponseSink;
43
44 #[inline]
45 fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
46 self.split()
47 }
48}
49
/// Drives a language server over a pair of byte streams plus a loopback channel.
///
/// Build it with [`Server::new`], optionally tune it with [`Server::concurrency_level`], and run
/// it with [`Server::serve`].
#[derive(Debug)]
pub struct Server<I, O, L = ClientSocket> {
    // Byte stream that incoming messages are decoded from.
    stdin: I,
    // Byte stream that outgoing messages are encoded to.
    stdout: O,
    // Source of server-to-client requests and sink for the client's responses.
    loopback: L,
    // Maximum number of service futures `serve` polls at the same time.
    max_concurrency: usize,
}
58
59impl<I, O, L> Server<I, O, L>
60where
61 I: AsyncRead + Unpin,
62 O: AsyncWrite,
63 L: Loopback,
64 <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
65{
66 pub fn new(stdin: I, stdout: O, socket: L) -> Self {
68 Server {
69 stdin,
70 stdout,
71 loopback: socket,
72 max_concurrency: DEFAULT_MAX_CONCURRENCY,
73 }
74 }
75
76 pub fn concurrency_level(mut self, max: usize) -> Self {
97 self.max_concurrency = max;
98 self
99 }
100
101 pub async fn serve<T>(self, mut service: T)
103 where
104 T: Service<Request, Response = Option<Response>> + Send + 'static,
105 T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
106 T::Future: Send,
107 {
108 let (client_requests, mut client_responses) = self.loopback.split();
109 let (client_requests, client_abort) = stream::abortable(client_requests);
110 let (mut responses_tx, responses_rx) = mpsc::channel(0);
111 let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);
112
113 let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
114 let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());
115
116 let process_server_tasks = server_tasks_rx
117 .buffer_unordered(self.max_concurrency)
118 .filter_map(future::ready)
119 .map(|res| Ok(Message::Response(res)))
120 .forward(responses_tx.clone().sink_map_err(|_| unreachable!()))
121 .map(|_| ());
122
123 let print_output = stream::select(responses_rx, client_requests.map(Message::Request))
124 .map(Ok)
125 .forward(framed_stdout.sink_map_err(|e| error!("failed to encode message: {}", e)))
126 .map(|_| ());
127
128 let read_input = async {
129 while let Some(msg) = framed_stdin.next().await {
130 match msg {
131 Ok(Message::Request(req)) => {
132 if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
133 error!("{}", display_sources(err.into().as_ref()));
134 return;
135 }
136
137 let fut = service.call(req).unwrap_or_else(|err| {
138 error!("{}", display_sources(err.into().as_ref()));
139 None
140 });
141
142 server_tasks_tx.send(fut).await.unwrap();
143 }
144 Ok(Message::Response(res)) => {
145 if let Err(err) = client_responses.send(res).await {
146 error!("{}", display_sources(&err));
147 return;
148 }
149 }
150 Err(err) => {
151 error!("failed to decode message: {}", err);
152 let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
153 responses_tx.send(Message::Response(res)).await.unwrap();
154 }
155 }
156 }
157
158 server_tasks_tx.disconnect();
159 responses_tx.disconnect();
160 client_abort.abort();
161 };
162
163 join!(print_output, read_input, process_server_tasks);
164 }
165}
166
/// Renders `error` followed by each error in its `source()` chain, separated by `": "`.
///
/// For example, an error "outer" caused by "inner" renders as `"outer: inner"`.
fn display_sources(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    let mut cause = error.source();
    while let Some(inner) = cause {
        rendered.push_str(": ");
        rendered.push_str(&inner.to_string());
        cause = inner.source();
    }
    rendered
}
174
#[cfg(feature = "runtime-tokio")]
/// Maps a message-decoding failure to the JSON-RPC error sent back to the client.
///
/// A `ParseError::Body` whose `is_data()` is true becomes `invalid_request`. Presumably that
/// means valid JSON with the wrong shape (serde_json's "data" category). Every other failure
/// becomes `parse_error`.
fn to_jsonrpc_error(err: ParseError) -> Error {
    let wrong_shape = matches!(&err, ParseError::Body(body) if body.is_data());
    if wrong_shape {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
182
#[cfg(feature = "runtime-agnostic")]
/// Maps a message-decoding failure to the JSON-RPC error sent back to the client.
///
/// The framing layer's error is inspected through its `source()`, which is expected to be a
/// `ParseError` (assumption: confirm against the async-codec-lite error type). A
/// `ParseError::Body` whose `is_data()` is true becomes `invalid_request`. Every other failure,
/// including a source that is not a `ParseError`, becomes `parse_error`.
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    let parse_error = err
        .source()
        .and_then(|cause| cause.downcast_ref::<ParseError>());

    if let Some(ParseError::Body(body)) = parse_error {
        if body.is_data() {
            return Error::invalid_request();
        }
    }

    Error::parse_error()
}
190
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Service that is always ready and answers every request with the canned `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let canned: Response = serde_json::from_str(RESPONSE).unwrap();
            future::ready(Ok(Some(canned)))
        }
    }

    /// Loopback that emits a fixed list of server-to-client requests and discards all replies.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            let MockLoopback(requests) = self;
            (stream::iter(requests), sink::drain())
        }
    }

    /// Prefixes `body` with an LSP `Content-Length` header.
    fn frame(body: &str) -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body).into_bytes()
    }

    fn mock_request() -> Vec<u8> {
        frame(REQUEST)
    }

    fn mock_response() -> Vec<u8> {
        frame(RESPONSE)
    }

    /// An input holding one framed `REQUEST` and an empty output buffer.
    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        let input = Cursor::new(mock_request());
        let output = Vec::new();
        (input, output)
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut input, mut output) = mock_stdio();
        let server = Server::new(&mut input, &mut output, MockLoopback(Vec::new()));
        server.serve(MockService).await;

        assert_eq!(input.position(), 80);
        assert_eq!(output, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let outgoing: Request = serde_json::from_str(REQUEST).unwrap();
        let socket = MockLoopback(vec![outgoing]);

        let (mut input, mut output) = mock_stdio();
        let server = Server::new(&mut input, &mut output, socket);
        server.serve(MockService).await;

        assert_eq!(input.position(), 80);
        let mut expected = mock_request();
        expected.extend(mock_response());
        assert_eq!(output, expected);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let mut input = Cursor::new(frame(invalid));
        let mut output = Vec::<u8>::new();

        let server = Server::new(&mut input, &mut output, MockLoopback(Vec::new()));
        server.serve(MockService).await;

        assert_eq!(input.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        assert_eq!(output, frame(err));
    }
}