1use std::marker::PhantomData;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures::future::{self, BoxFuture, FutureExt};
8use tower::{Layer, Service};
9use tracing::{info, warn};
10
11use super::ExitedError;
12use crate::jsonrpc::{not_initialized_error, Error, Id, Request, Response};
13
14use super::client::Client;
15use super::pending::Pending;
16use super::state::{ServerState, State};
17
18pub struct Initialize {
24 state: Arc<ServerState>,
25 pending: Arc<Pending>,
26}
27
28impl Initialize {
29 pub fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
30 Initialize { state, pending }
31 }
32}
33
34impl<S> Layer<S> for Initialize {
35 type Service = InitializeService<S>;
36
37 fn layer(&self, inner: S) -> Self::Service {
38 InitializeService {
39 inner: Cancellable::new(inner, self.pending.clone()),
40 state: self.state.clone(),
41 }
42 }
43}
44
45pub struct InitializeService<S> {
47 inner: Cancellable<S>,
48 state: Arc<ServerState>,
49}
50
51impl<S> Service<Request> for InitializeService<S>
52where
53 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
54 S::Future: Send + 'static,
55{
56 type Response = S::Response;
57 type Error = S::Error;
58 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
59
60 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61 self.inner.poll_ready(cx)
62 }
63
64 fn call(&mut self, req: Request) -> Self::Future {
65 if self.state.get() == State::Uninitialized {
66 let state = self.state.clone();
67 let fut = self.inner.call(req);
68
69 Box::pin(async move {
70 let response = fut.await?;
71
72 match &response {
73 Some(res) if res.is_ok() => state.set(State::Initialized),
74 _ => state.set(State::Uninitialized),
75 }
76
77 Ok(response)
78 })
79 } else {
80 warn!("received duplicate `initialize` request, ignoring");
81 let (_, id, _) = req.into_parts();
82 future::ok(id.map(|id| Response::from_error(id, Error::invalid_request()))).boxed()
83 }
84 }
85}
86
87pub struct Shutdown {
93 state: Arc<ServerState>,
94 pending: Arc<Pending>,
95}
96
97impl Shutdown {
98 pub fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
99 Shutdown { state, pending }
100 }
101}
102
103impl<S> Layer<S> for Shutdown {
104 type Service = ShutdownService<S>;
105
106 fn layer(&self, inner: S) -> Self::Service {
107 ShutdownService {
108 inner: Cancellable::new(inner, self.pending.clone()),
109 state: self.state.clone(),
110 }
111 }
112}
113
114pub struct ShutdownService<S> {
116 inner: Cancellable<S>,
117 state: Arc<ServerState>,
118}
119
120impl<S> Service<Request> for ShutdownService<S>
121where
122 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
123 S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
124{
125 type Response = S::Response;
126 type Error = S::Error;
127 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
128
129 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130 self.inner.poll_ready(cx)
131 }
132
133 fn call(&mut self, req: Request) -> Self::Future {
134 match self.state.get() {
135 State::Initialized => {
136 info!("shutdown request received, shutting down");
137 self.state.set(State::ShutDown);
138 self.inner.call(req)
139 }
140 cur_state => {
141 let (_, id, _) = req.into_parts();
142 future::ok(not_initialized_response(id, cur_state)).boxed()
143 }
144 }
145 }
146}
147
148pub struct Exit {
154 state: Arc<ServerState>,
155 pending: Arc<Pending>,
156 client: Client,
157}
158
159impl Exit {
160 pub fn new(state: Arc<ServerState>, pending: Arc<Pending>, client: Client) -> Self {
161 Exit {
162 state,
163 pending,
164 client,
165 }
166 }
167}
168
169impl<S> Layer<S> for Exit {
170 type Service = ExitService<S>;
171
172 fn layer(&self, _: S) -> Self::Service {
173 ExitService {
174 state: self.state.clone(),
175 pending: self.pending.clone(),
176 client: self.client.clone(),
177 _marker: PhantomData,
178 }
179 }
180}
181
182pub struct ExitService<S> {
184 state: Arc<ServerState>,
185 pending: Arc<Pending>,
186 client: Client,
187 _marker: PhantomData<S>,
188}
189
190impl<S> Service<Request> for ExitService<S> {
191 type Response = Option<Response>;
192 type Error = ExitedError;
193 type Future = future::Ready<Result<Self::Response, Self::Error>>;
194
195 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196 if self.state.get() == State::Exited {
197 Poll::Ready(Err(ExitedError(())))
198 } else {
199 Poll::Ready(Ok(()))
200 }
201 }
202
203 fn call(&mut self, _: Request) -> Self::Future {
204 info!("exit notification received, stopping");
205 self.state.set(State::Exited);
206 self.pending.cancel_all();
207 self.client.close();
208 future::ok(None)
209 }
210}
211
212pub struct Normal {
214 state: Arc<ServerState>,
215 pending: Arc<Pending>,
216}
217
218impl Normal {
219 pub fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
220 Normal { state, pending }
221 }
222}
223
224impl<S> Layer<S> for Normal {
225 type Service = NormalService<S>;
226
227 fn layer(&self, inner: S) -> Self::Service {
228 NormalService {
229 inner: Cancellable::new(inner, self.pending.clone()),
230 state: self.state.clone(),
231 }
232 }
233}
234
235pub struct NormalService<S> {
237 inner: Cancellable<S>,
238 state: Arc<ServerState>,
239}
240
241impl<S> Service<Request> for NormalService<S>
242where
243 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
244 S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
245{
246 type Response = S::Response;
247 type Error = S::Error;
248 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
249
250 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
251 self.inner.poll_ready(cx)
252 }
253
254 fn call(&mut self, req: Request) -> Self::Future {
255 match self.state.get() {
256 State::Initialized => self.inner.call(req),
257 cur_state => {
258 let (_, id, _) = req.into_parts();
259 future::ok(not_initialized_response(id, cur_state)).boxed()
260 }
261 }
262 }
263}
264
265struct Cancellable<S> {
271 inner: S,
272 pending: Arc<Pending>,
273}
274
275impl<S> Cancellable<S> {
276 fn new(inner: S, pending: Arc<Pending>) -> Self {
277 Cancellable { inner, pending }
278 }
279}
280
281impl<S> Service<Request> for Cancellable<S>
282where
283 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
284 S::Future: Send + 'static,
285{
286 type Response = S::Response;
287 type Error = S::Error;
288 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
289
290 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
291 self.inner.poll_ready(cx)
292 }
293
294 fn call(&mut self, req: Request) -> Self::Future {
295 match req.id().cloned() {
296 Some(id) => self.pending.execute(id, self.inner.call(req)).boxed(),
297 None => self.inner.call(req).boxed(),
298 }
299 }
300}
301
302fn not_initialized_response(id: Option<Id>, server_state: State) -> Option<Response> {
303 let id = id?;
304 let error = match server_state {
305 State::Uninitialized | State::Initializing => not_initialized_error(),
306 _ => Error::invalid_request(),
307 };
308
309 Some(Response::from_error(id, error))
310}
311
312