tower_lsp/jsonrpc/
router.rs

1//! Lightweight JSON-RPC router service.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::fmt::{self, Debug, Formatter};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use futures::future::{self, BoxFuture, FutureExt};
12use serde::{de::DeserializeOwned, Serialize};
13use serde_json::Value;
14use tower::{util::BoxService, Layer, Service};
15
16use crate::jsonrpc::ErrorCode;
17
18use super::{Error, Id, Request, Response};
19
20/// A modular JSON-RPC 2.0 request router service.
21pub struct Router<S, E = Infallible> {
22    server: Arc<S>,
23    methods: HashMap<&'static str, BoxService<Request, Option<Response>, E>>,
24}
25
26impl<S: Send + Sync + 'static, E> Router<S, E> {
27    /// Creates a new `Router` with the given shared state.
28    pub fn new(server: S) -> Self {
29        Router {
30            server: Arc::new(server),
31            methods: HashMap::new(),
32        }
33    }
34
35    /// Returns a reference to the inner server.
36    pub fn inner(&self) -> &S {
37        self.server.as_ref()
38    }
39
40    /// Registers a new RPC method which constructs a response with the given `callback`.
41    ///
42    /// The `layer` argument can be used to inject middleware into the method handler, if desired.
43    pub fn method<P, R, F, L>(&mut self, name: &'static str, callback: F, layer: L) -> &mut Self
44    where
45        P: FromParams,
46        R: IntoResponse,
47        F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
48        L: Layer<MethodHandler<P, R, E>>,
49        L::Service: Service<Request, Response = Option<Response>, Error = E> + Send + 'static,
50        <L::Service as Service<Request>>::Future: Send + 'static,
51    {
52        let server = &self.server;
53        self.methods.entry(name).or_insert_with(|| {
54            let server = server.clone();
55            let handler = MethodHandler::new(move |params| {
56                let callback = callback.clone();
57                let server = server.clone();
58                async move { callback.invoke(&*server, params).await }
59            });
60
61            BoxService::new(layer.layer(handler))
62        });
63
64        self
65    }
66}
67
68impl<S: Debug, E> Debug for Router<S, E> {
69    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
70        f.debug_struct("Router")
71            .field("server", &self.server)
72            .field("methods", &self.methods.keys())
73            .finish()
74    }
75}
76
77impl<S, E: Send + 'static> Service<Request> for Router<S, E> {
78    type Response = Option<Response>;
79    type Error = E;
80    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
81
82    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83        Poll::Ready(Ok(()))
84    }
85
86    fn call(&mut self, req: Request) -> Self::Future {
87        if let Some(handler) = self.methods.get_mut(req.method()) {
88            handler.call(req)
89        } else {
90            let (method, id, _) = req.into_parts();
91            future::ok(id.map(|id| {
92                let mut error = Error::method_not_found();
93                error.data = Some(Value::from(method));
94                Response::from_error(id, error)
95            }))
96            .boxed()
97        }
98    }
99}
100
101/// Opaque JSON-RPC method handler.
102pub struct MethodHandler<P, R, E> {
103    f: Box<dyn Fn(P) -> BoxFuture<'static, R> + Send>,
104    _marker: PhantomData<E>,
105}
106
107impl<P: FromParams, R: IntoResponse, E> MethodHandler<P, R, E> {
108    fn new<F, Fut>(handler: F) -> Self
109    where
110        F: Fn(P) -> Fut + Send + 'static,
111        Fut: Future<Output = R> + Send + 'static,
112    {
113        MethodHandler {
114            f: Box::new(move |p| handler(p).boxed()),
115            _marker: PhantomData,
116        }
117    }
118}
119
120impl<P, R, E> Service<Request> for MethodHandler<P, R, E>
121where
122    P: FromParams,
123    R: IntoResponse,
124    E: Send + 'static,
125{
126    type Response = Option<Response>;
127    type Error = E;
128    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
129
130    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131        Poll::Ready(Ok(()))
132    }
133
134    fn call(&mut self, req: Request) -> Self::Future {
135        let (_, id, params) = req.into_parts();
136
137        match id {
138            Some(_) if R::is_notification() => return future::ok(().into_response(id)).boxed(),
139            None if !R::is_notification() => return future::ok(None).boxed(),
140            _ => {}
141        }
142
143        let params = match P::from_params(params) {
144            Ok(params) => params,
145            Err(err) => return future::ok(id.map(|id| Response::from_error(id, err))).boxed(),
146        };
147
148        (self.f)(params)
149            .map(move |r| Ok(r.into_response(id)))
150            .boxed()
151    }
152}
153
154/// A trait implemented by all valid JSON-RPC method handlers.
155///
156/// This trait abstracts over the following classes of functions and/or closures:
157///
158/// Signature                                            | Description
159/// -----------------------------------------------------|---------------------------------
160/// `async fn f(&self) -> jsonrpc::Result<R>`            | Request without parameters
161/// `async fn f(&self, params: P) -> jsonrpc::Result<R>` | Request with required parameters
162/// `async fn f(&self)`                                  | Notification without parameters
163/// `async fn f(&self, params: P)`                       | Notification with parameters
164pub trait Method<S, P, R>: private::Sealed {
165    /// The future response value.
166    type Future: Future<Output = R> + Send;
167
168    /// Invokes the method with the given `server` receiver and parameters.
169    fn invoke(&self, server: S, params: P) -> Self::Future;
170}
171
172/// Support parameter-less JSON-RPC methods.
173impl<F, S, R, Fut> Method<S, (), R> for F
174where
175    F: Fn(S) -> Fut,
176    Fut: Future<Output = R> + Send,
177{
178    type Future = Fut;
179
180    #[inline]
181    fn invoke(&self, server: S, _: ()) -> Self::Future {
182        self(server)
183    }
184}
185
186/// Support JSON-RPC methods with `params`.
187impl<F, S, P, R, Fut> Method<S, (P,), R> for F
188where
189    F: Fn(S, P) -> Fut,
190    P: DeserializeOwned,
191    Fut: Future<Output = R> + Send,
192{
193    type Future = Fut;
194
195    #[inline]
196    fn invoke(&self, server: S, params: (P,)) -> Self::Future {
197        self(server, params.0)
198    }
199}
200
201/// A trait implemented by all JSON-RPC method parameters.
202pub trait FromParams: private::Sealed + Send + Sized + 'static {
203    /// Attempts to deserialize `Self` from the `params` value extracted from [`Request`].
204    fn from_params(params: Option<Value>) -> super::Result<Self>;
205}
206
207/// Deserialize non-existent JSON-RPC parameters.
208impl FromParams for () {
209    fn from_params(params: Option<Value>) -> super::Result<Self> {
210        if let Some(p) = params {
211            Err(Error::invalid_params(format!("Unexpected params: {p}")))
212        } else {
213            Ok(())
214        }
215    }
216}
217
218/// Deserialize required JSON-RPC parameters.
219impl<P: DeserializeOwned + Send + 'static> FromParams for (P,) {
220    fn from_params(params: Option<Value>) -> super::Result<Self> {
221        if let Some(p) = params {
222            serde_json::from_value(p)
223                .map(|params| (params,))
224                .map_err(|e| Error::invalid_params(e.to_string()))
225        } else {
226            Err(Error::invalid_params("Missing params field"))
227        }
228    }
229}
230
231/// A trait implemented by all JSON-RPC response types.
232pub trait IntoResponse: private::Sealed + Send + 'static {
233    /// Attempts to construct a [`Response`] using `Self` and a corresponding [`Id`].
234    fn into_response(self, id: Option<Id>) -> Option<Response>;
235
236    /// Returns `true` if this is a notification response type.
237    fn is_notification() -> bool;
238}
239
240/// Support JSON-RPC notification methods.
241impl IntoResponse for () {
242    fn into_response(self, id: Option<Id>) -> Option<Response> {
243        id.map(|id| Response::from_error(id, Error::invalid_request()))
244    }
245
246    #[inline]
247    fn is_notification() -> bool {
248        true
249    }
250}
251
252/// Support JSON-RPC request methods.
253impl<R: Serialize + Send + 'static> IntoResponse for Result<R, Error> {
254    fn into_response(self, id: Option<Id>) -> Option<Response> {
255        debug_assert!(id.is_some(), "Requests always contain an `id` field");
256        if let Some(id) = id {
257            let result = self.and_then(|r| {
258                serde_json::to_value(r).map_err(|e| Error {
259                    code: ErrorCode::InternalError,
260                    message: e.to_string().into(),
261                    data: None,
262                })
263            });
264            Some(Response::from_parts(id, result))
265        } else {
266            None
267        }
268    }
269
270    #[inline]
271    fn is_notification() -> bool {
272        false
273    }
274}
275
mod private {
    /// Supertrait bound shared by `Method`, `FromParams` and `IntoResponse`.
    ///
    /// NOTE(review): the blanket impl below makes every sized type satisfy this
    /// bound. It therefore does not restrict which types can implement those
    /// traits.
    pub trait Sealed {}

    impl<T> Sealed for T {}
}
280
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tower::layer::layer_fn;
    use tower::ServiceExt;

    use super::*;

    #[derive(Deserialize, Serialize)]
    struct Params {
        foo: i32,
        bar: String,
    }

    struct Mock;

    impl Mock {
        async fn request(&self) -> Result<Value, Error> {
            Ok(Value::Null)
        }

        async fn request_params(&self, params: Params) -> Result<Params, Error> {
            Ok(params)
        }

        async fn notification(&self) {}

        async fn notification_params(&self, _params: Params) {}
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_requests() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::request_params, layer_fn(|s| s));

        let request = Request::build("first").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(Some(Response::from_ok(0.into(), Value::Null))));

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second")
            .params(params.clone())
            .id(1)
            .finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(Some(Response::from_ok(1.into(), params))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_notifications() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::notification, layer_fn(|s| s))
            .method("second", Mock::notification_params, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second").params(params).finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("request")
            .params(json!("wrong"))
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("invalid type: string \"wrong\", expected struct Params"),
            )))
        );
    }

    // A request handler that requires params must reject a call without them.
    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_missing_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let missing_params = Request::build("request").id(0).finish();

        let response = router.ready().await.unwrap().call(missing_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("Missing params field"),
            )))
        );
    }

    // A parameterless request handler must reject a call that carries params.
    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_unexpected_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request, layer_fn(|s| s));

        let unexpected_params = Request::build("request")
            .params(json!([]))
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(unexpected_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("Unexpected params: []"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_notification_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("notification", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("notification")
            .params(json!("wrong"))
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_incorrect_request_types() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::notification, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let request = Request::build("second").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_request(),
            )))
        );
    }

    // Registering the same name twice keeps the first handler.
    #[tokio::test(flavor = "current_thread")]
    async fn keeps_first_registration_of_duplicate_method() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("first", Mock::notification, layer_fn(|s| s));

        let request = Request::build("first").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(Some(Response::from_ok(0.into(), Value::Null))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn responds_to_nonexistent_request() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        let mut error = Error::method_not_found();
        error.data = Some("foo".into());
        assert_eq!(response, Ok(Some(Response::from_error(0.into(), error))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_nonexistent_notification() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));
    }
}