1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::fmt::{self, Debug, Formatter};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use futures::future::{self, BoxFuture, FutureExt};
12use serde::{de::DeserializeOwned, Serialize};
13use serde_json::Value;
14use tower::{util::BoxService, Layer, Service};
15
16use crate::jsonrpc::ErrorCode;
17
18use super::{Error, Id, Request, Response};
19
20pub struct Router<S, E = Infallible> {
22 server: Arc<S>,
23 methods: HashMap<&'static str, BoxService<Request, Option<Response>, E>>,
24}
25
26impl<S: Send + Sync + 'static, E> Router<S, E> {
27 pub fn new(server: S) -> Self {
29 Router {
30 server: Arc::new(server),
31 methods: HashMap::new(),
32 }
33 }
34
35 pub fn inner(&self) -> &S {
37 self.server.as_ref()
38 }
39
40 pub fn method<P, R, F, L>(&mut self, name: &'static str, callback: F, layer: L) -> &mut Self
44 where
45 P: FromParams,
46 R: IntoResponse,
47 F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
48 L: Layer<MethodHandler<P, R, E>>,
49 L::Service: Service<Request, Response = Option<Response>, Error = E> + Send + 'static,
50 <L::Service as Service<Request>>::Future: Send + 'static,
51 {
52 let server = &self.server;
53 self.methods.entry(name).or_insert_with(|| {
54 let server = server.clone();
55 let handler = MethodHandler::new(move |params| {
56 let callback = callback.clone();
57 let server = server.clone();
58 async move { callback.invoke(&*server, params).await }
59 });
60
61 BoxService::new(layer.layer(handler))
62 });
63
64 self
65 }
66}
67
68impl<S: Debug, E> Debug for Router<S, E> {
69 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
70 f.debug_struct("Router")
71 .field("server", &self.server)
72 .field("methods", &self.methods.keys())
73 .finish()
74 }
75}
76
77impl<S, E: Send + 'static> Service<Request> for Router<S, E> {
78 type Response = Option<Response>;
79 type Error = E;
80 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
81
82 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83 Poll::Ready(Ok(()))
84 }
85
86 fn call(&mut self, req: Request) -> Self::Future {
87 if let Some(handler) = self.methods.get_mut(req.method()) {
88 handler.call(req)
89 } else {
90 let (method, id, _) = req.into_parts();
91 future::ok(id.map(|id| {
92 let mut error = Error::method_not_found();
93 error.data = Some(Value::from(method));
94 Response::from_error(id, error)
95 }))
96 .boxed()
97 }
98 }
99}
100
101pub struct MethodHandler<P, R, E> {
103 f: Box<dyn Fn(P) -> BoxFuture<'static, R> + Send>,
104 _marker: PhantomData<E>,
105}
106
107impl<P: FromParams, R: IntoResponse, E> MethodHandler<P, R, E> {
108 fn new<F, Fut>(handler: F) -> Self
109 where
110 F: Fn(P) -> Fut + Send + 'static,
111 Fut: Future<Output = R> + Send + 'static,
112 {
113 MethodHandler {
114 f: Box::new(move |p| handler(p).boxed()),
115 _marker: PhantomData,
116 }
117 }
118}
119
120impl<P, R, E> Service<Request> for MethodHandler<P, R, E>
121where
122 P: FromParams,
123 R: IntoResponse,
124 E: Send + 'static,
125{
126 type Response = Option<Response>;
127 type Error = E;
128 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
129
130 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131 Poll::Ready(Ok(()))
132 }
133
134 fn call(&mut self, req: Request) -> Self::Future {
135 let (_, id, params) = req.into_parts();
136
137 match id {
138 Some(_) if R::is_notification() => return future::ok(().into_response(id)).boxed(),
139 None if !R::is_notification() => return future::ok(None).boxed(),
140 _ => {}
141 }
142
143 let params = match P::from_params(params) {
144 Ok(params) => params,
145 Err(err) => return future::ok(id.map(|id| Response::from_error(id, err))).boxed(),
146 };
147
148 (self.f)(params)
149 .map(move |r| Ok(r.into_response(id)))
150 .boxed()
151 }
152}
153
154pub trait Method<S, P, R>: private::Sealed {
165 type Future: Future<Output = R> + Send;
167
168 fn invoke(&self, server: S, params: P) -> Self::Future;
170}
171
172impl<F, S, R, Fut> Method<S, (), R> for F
174where
175 F: Fn(S) -> Fut,
176 Fut: Future<Output = R> + Send,
177{
178 type Future = Fut;
179
180 #[inline]
181 fn invoke(&self, server: S, _: ()) -> Self::Future {
182 self(server)
183 }
184}
185
186impl<F, S, P, R, Fut> Method<S, (P,), R> for F
188where
189 F: Fn(S, P) -> Fut,
190 P: DeserializeOwned,
191 Fut: Future<Output = R> + Send,
192{
193 type Future = Fut;
194
195 #[inline]
196 fn invoke(&self, server: S, params: (P,)) -> Self::Future {
197 self(server, params.0)
198 }
199}
200
201pub trait FromParams: private::Sealed + Send + Sized + 'static {
203 fn from_params(params: Option<Value>) -> super::Result<Self>;
205}
206
207impl FromParams for () {
209 fn from_params(params: Option<Value>) -> super::Result<Self> {
210 if let Some(p) = params {
211 Err(Error::invalid_params(format!("Unexpected params: {p}")))
212 } else {
213 Ok(())
214 }
215 }
216}
217
218impl<P: DeserializeOwned + Send + 'static> FromParams for (P,) {
220 fn from_params(params: Option<Value>) -> super::Result<Self> {
221 if let Some(p) = params {
222 serde_json::from_value(p)
223 .map(|params| (params,))
224 .map_err(|e| Error::invalid_params(e.to_string()))
225 } else {
226 Err(Error::invalid_params("Missing params field"))
227 }
228 }
229}
230
231pub trait IntoResponse: private::Sealed + Send + 'static {
233 fn into_response(self, id: Option<Id>) -> Option<Response>;
235
236 fn is_notification() -> bool;
238}
239
240impl IntoResponse for () {
242 fn into_response(self, id: Option<Id>) -> Option<Response> {
243 id.map(|id| Response::from_error(id, Error::invalid_request()))
244 }
245
246 #[inline]
247 fn is_notification() -> bool {
248 true
249 }
250}
251
252impl<R: Serialize + Send + 'static> IntoResponse for Result<R, Error> {
254 fn into_response(self, id: Option<Id>) -> Option<Response> {
255 debug_assert!(id.is_some(), "Requests always contain an `id` field");
256 if let Some(id) = id {
257 let result = self.and_then(|r| {
258 serde_json::to_value(r).map_err(|e| Error {
259 code: ErrorCode::InternalError,
260 message: e.to_string().into(),
261 data: None,
262 })
263 });
264 Some(Response::from_parts(id, result))
265 } else {
266 None
267 }
268 }
269
270 #[inline]
271 fn is_notification() -> bool {
272 false
273 }
274}
275
mod private {
    /// Supertrait of the public handler traits (`Method`, `FromParams`, `IntoResponse`).
    ///
    /// NOTE(review): the blanket impl below makes every sized type `Sealed`, so this bound does
    /// not by itself restrict which types can implement those traits — confirm this is intended.
    pub trait Sealed {}

    impl<T> Sealed for T {}
}
280
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tower::layer::layer_fn;
    use tower::ServiceExt;

    use super::*;

    #[derive(Deserialize, Serialize)]
    struct Params {
        foo: i32,
        bar: String,
    }

    struct Mock;

    impl Mock {
        async fn request(&self) -> Result<Value, Error> {
            Ok(Value::Null)
        }

        async fn request_params(&self, params: Params) -> Result<Params, Error> {
            Ok(params)
        }

        async fn notification(&self) {}

        async fn notification_params(&self, _params: Params) {}
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_requests() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::request_params, layer_fn(|s| s));

        let request = Request::build("first").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(Some(Response::from_ok(0.into(), Value::Null))));

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second")
            .params(params.clone())
            .id(1)
            .finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(Some(Response::from_ok(1.into(), params))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_notifications() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::notification, layer_fn(|s| s))
            .method("second", Mock::notification_params, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second").params(params).finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("request")
            .params(json!("wrong"))
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("invalid type: string \"wrong\", expected struct Params"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_missing_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let missing_params = Request::build("request").id(0).finish();

        let response = router.ready().await.unwrap().call(missing_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("Missing params field"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_unexpected_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request, layer_fn(|s| s));

        let unexpected_params = Request::build("request")
            .params(json!([1, 2]))
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(unexpected_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("Unexpected params: [1,2]"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_notification_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("notification", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("notification")
            .params(json!("wrong"))
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_incorrect_request_types() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::notification, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let request = Request::build("second").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_request(),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn keeps_first_registration_for_duplicate_method() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("first", Mock::notification, layer_fn(|s| s));

        // The second registration must be ignored, so "first" still behaves as a request.
        let request = Request::build("first").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(Some(Response::from_ok(0.into(), Value::Null))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn responds_to_nonexistent_request() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        let mut error = Error::method_not_found();
        error.data = Some("foo".into());
        assert_eq!(response, Ok(Some(Response::from_error(0.into(), error))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_nonexistent_notification() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));
    }
}