1use std::collections::BTreeMap;
20use std::io::{BufRead, Write};
21use std::sync::{Arc, Mutex};
22
23use anyhow::anyhow;
24use async_trait::async_trait;
25use mz_persist::location::ExternalError;
26use mz_persist_client::ShardId;
27use serde_json::Value;
28use tokio::sync::oneshot;
29use tracing::{Instrument, debug_span, info, trace};
30
31use crate::maelstrom::Args;
32use crate::maelstrom::api::{Body, ErrorCode, MaelstromError, Msg, MsgId, NodeId};
33
/// A request handler hosted by a Maelstrom node.
///
/// In this file the node builds one instance with [`Service::init`] when it
/// handles an init request, and afterwards passes each request addressed to it
/// (that is not a reply to one of its own service requests) to
/// [`Service::eval`].
#[async_trait]
pub trait Service: Sized + Send + Sync {
    /// Constructs the service from the process arguments.
    ///
    /// `handle` is available for sending messages during initialization.
    /// The caller in this file panics if an error is returned.
    async fn init(args: &Args, handle: &Handle) -> Result<Self, MaelstromError>;

    /// Processes a single request body `req` that was sent by `src`.
    ///
    /// Nothing is returned, so any reply has to be sent through `handle`.
    async fn eval(&self, handle: Handle, src: NodeId, req: Body);
}
50
51pub fn run<R, W, S>(args: Args, read: R, write: W) -> Result<(), anyhow::Error>
54where
55 R: BufRead,
56 W: Write + Send + Sync + 'static,
57 S: Service + 'static,
58{
59 let mut node = Node::<S>::new(args, write);
60 for line in read.lines() {
61 let line = line.map_err(|err| anyhow!("req read failed: {}", err))?;
62 trace!("raw: [{}]", line);
63 let req: Msg = line
64 .parse()
65 .map_err(|err| anyhow!("invalid req {}: {}", err, line))?;
66 if req.body.in_reply_to().is_none() {
67 info!("req: {}", req);
68 } else {
69 trace!("req: {}", req);
70 }
71 node.handle(req);
72 }
73 Ok(())
74}
75
/// State for one Maelstrom node hosting a service of type `S`.
struct Node<S>
where
    S: Service + 'static,
{
    /// Arguments that are cloned and passed to `S::init`.
    args: Args,
    /// State shared with every [`Handle`] this node creates.
    core: Arc<Mutex<Core>>,
    /// This node's id; `None` until an init request has been handled.
    node_id: Option<NodeId>,
    /// The service instance, set once the spawned initialization task finishes.
    service: Arc<AsyncInitOnceWaitable<Arc<S>>>,
}
85
86impl<S> Node<S>
87where
88 S: Service + 'static,
89{
90 fn new<W>(args: Args, write: W) -> Self
91 where
92 W: Write + Send + Sync + 'static,
93 {
94 let core = Core {
95 write: Box::new(write),
96 next_msg_id: MsgId(0),
97 callbacks: BTreeMap::new(),
98 };
99 Node {
100 args,
101 core: Arc::new(Mutex::new(core)),
102 node_id: None,
103 service: Arc::new(AsyncInitOnceWaitable::new()),
104 }
105 }
106
107 pub fn handle(&mut self, msg: Msg) {
108 if let Some(node_id) = self.node_id.as_ref() {
111 if node_id != &msg.dest {
113 return;
114 }
115
116 let handle = Handle {
117 node_id: node_id.clone(),
118 core: Arc::clone(&self.core),
119 };
120
121 let body = match handle.maybe_handle_service_res(&msg.src, msg.body) {
122 Ok(()) => return,
123 Err(x) => x,
124 };
125
126 let service = Arc::clone(&self.service);
127 mz_ore::task::spawn(
128 || "maelstrom::handle".to_string(),
129 async move {
130 let service = service.get().await;
131 let () = service.eval(handle, msg.src, body).await;
132 }
133 .instrument(debug_span!("maelstrom::handle")),
134 );
135 return;
136 }
137
138 match msg.body {
141 Body::ReqInit {
142 msg_id, node_id, ..
143 } => {
144 self.node_id = Some(node_id.clone());
146 let handle = Handle {
147 node_id,
148 core: Arc::clone(&self.core),
149 };
150
151 let in_reply_to = msg_id;
158 handle.send_res(msg.src, move |msg_id| Body::ResInit {
159 msg_id,
160 in_reply_to,
161 });
162
163 let args = self.args.clone();
169 let service_init = Arc::clone(&self.service);
170 mz_ore::task::spawn(
171 || "maelstrom::init".to_string(),
172 async move {
173 let service = match S::init(&args, &handle).await {
174 Ok(x) => x,
175 Err(err) => {
176 panic!("service initialization failed: {}", err);
180 }
181 };
182 service_init.init_once(Arc::new(service)).await;
183 }
184 .instrument(debug_span!("maelstrom::init")),
185 );
186 }
187 _ => {}
189 }
190 }
191}
192
/// Mutable state shared, behind a mutex, by a node and its handles.
struct Core {
    /// Sink that outgoing messages are written to, one per line.
    write: Box<dyn Write + Send + Sync>,
    /// The most recently allocated message id (or the initial value if none
    /// has been allocated yet).
    next_msg_id: MsgId,
    /// Reply channels for service requests that are still outstanding, keyed
    /// by the message id of the request.
    callbacks: BTreeMap<MsgId, oneshot::Sender<Body>>,
}
198
199impl std::fmt::Debug for Core {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 let Core {
203 write: _,
204 next_msg_id,
205 callbacks,
206 } = self;
207 f.debug_struct("Core")
208 .field("next_msg_id", &next_msg_id)
209 .field("callbacks", &callbacks.keys().collect::<Vec<_>>())
210 .finish_non_exhaustive()
211 }
212}
213
214impl Core {
215 fn alloc_msg_id(&mut self) -> MsgId {
216 self.next_msg_id = self.next_msg_id.next();
217 self.next_msg_id
218 }
219}
220
/// A cloneable handle used to send messages on behalf of a node.
#[derive(Debug, Clone)]
pub struct Handle {
    /// Id used as the `src` of every message sent through this handle.
    node_id: NodeId,
    /// State shared with the owning node and with other handles.
    core: Arc<Mutex<Core>>,
}
227
228impl Handle {
229 pub fn node_id(&self) -> NodeId {
231 self.node_id.clone()
232 }
233
234 pub fn send_res<BodyFn: FnOnce(MsgId) -> Body>(&self, dest: NodeId, res_fn: BodyFn) {
239 let mut core = self.core.lock().expect("mutex poisoned");
240 let msg_id = core.alloc_msg_id();
241 let res = Msg {
242 src: self.node_id.clone(),
243 dest,
244 body: res_fn(msg_id),
245 };
246 info!("res: {}", res);
247 write!(core.write.as_mut(), "{}\n", res).expect("res write failed");
248 }
249
250 pub async fn send_service_req<BodyFn: FnOnce(MsgId) -> Body>(
255 &self,
256 dest: NodeId,
257 req_fn: BodyFn,
258 ) -> Body {
259 let (tx, rx) = oneshot::channel();
260 {
261 let mut core = self.core.lock().expect("mutex poisoned");
262 let msg_id = core.alloc_msg_id();
263 core.callbacks.insert(msg_id, tx);
264 let req = Msg {
265 src: self.node_id.clone(),
266 dest,
267 body: req_fn(msg_id),
268 };
269 trace!("svc: {}", req);
270 write!(core.write.as_mut(), "{}\n", req).expect("req write failed");
271 }
272 rx.await.expect("internal error: callback oneshot dropped")
273 }
274
275 #[allow(clippy::result_large_err)]
278 pub fn maybe_handle_service_res(&self, src: &NodeId, msg: Body) -> Result<(), Body> {
279 let in_reply_to = match msg.in_reply_to() {
280 Some(x) => x,
281 None => return Err(msg),
282 };
283
284 let mut core = self.core.lock().expect("mutex poisoned");
285 let callback = match core.callbacks.remove(&in_reply_to) {
286 Some(x) => x,
287 None => {
288 self.send_res(src.clone(), |msg_id| Body::Error {
289 msg_id: Some(msg_id),
290 in_reply_to,
291 code: ErrorCode::MalformedRequest,
292 text: format!("no callback expected for {:?}", in_reply_to),
293 });
294 return Ok(());
295 }
296 };
297
298 if let Err(_) = callback.send(msg) {
299 return Ok(());
301 }
302
303 Ok(())
304 }
305
306 pub async fn maybe_init_shard_id(&self) -> Result<ShardId, MaelstromError> {
310 let proposal = ShardId::new();
311 let key = "SHARD";
312 loop {
313 let from = Value::Null;
314 let to = Value::from(proposal.to_string());
315 match self
316 .lin_kv_compare_and_set(Value::from(key), from, to, Some(true))
317 .await
318 {
319 Ok(()) => {
320 info!("initialized maelstrom shard to {}", proposal);
321 return Ok(proposal);
322 }
323 Err(MaelstromError {
324 code: ErrorCode::PreconditionFailed,
325 ..
326 }) => match self.lin_kv_read(Value::from(key)).await? {
327 Some(value) => {
328 let value = value.as_str().ok_or_else(|| {
329 ExternalError::from(anyhow!("invalid SHARD {}", value))
330 })?;
331 let shard_id = value.parse::<ShardId>().map_err(|err| {
332 ExternalError::from(anyhow!("invalid SHARD {}: {}", value, err))
333 })?;
334 info!("fetched maelstrom shard id {}", shard_id);
335 return Ok(shard_id);
336 }
337 None => continue,
338 },
339 Err(err) => return Err(err),
340 }
341 }
342 }
343
344 pub async fn lin_kv_read(&self, key: Value) -> Result<Option<Value>, MaelstromError> {
346 let dest = NodeId("lin-kv".to_string());
347 let res = self
348 .send_service_req(dest, move |msg_id| Body::ReqLinKvRead { msg_id, key })
349 .await;
350 match res {
351 Body::Error {
352 code: ErrorCode::KeyDoesNotExist,
353 ..
354 } => Ok(None),
355 Body::Error { code, text, .. } => Err(MaelstromError { code, text }),
356 Body::ResLinKvRead { value, .. } => Ok(Some(value)),
357 res => unimplemented!("unsupported res: {:?}", res),
358 }
359 }
360
361 pub async fn lin_kv_write(&self, key: Value, value: Value) -> Result<(), MaelstromError> {
363 let dest = NodeId("lin-kv".to_string());
364 let res = self
365 .send_service_req(dest, move |msg_id| Body::ReqLinKvWrite {
366 msg_id,
367 key,
368 value,
369 })
370 .await;
371 match res {
372 Body::Error { code, text, .. } => Err(MaelstromError { code, text }),
373 Body::ResLinKvWrite { .. } => Ok(()),
374 res => unimplemented!("unsupported res: {:?}", res),
375 }
376 }
377
378 pub async fn lin_kv_compare_and_set(
380 &self,
381 key: Value,
382 from: Value,
383 to: Value,
384 create_if_not_exists: Option<bool>,
385 ) -> Result<(), MaelstromError> {
386 trace!(
387 "lin_kv_compare_and_set key={:?} from={:?} to={:?} create_if_not_exists={:?}",
388 key, from, to, create_if_not_exists
389 );
390 let dest = NodeId("lin-kv".to_string());
391 let res = self
392 .send_service_req(dest, move |msg_id| Body::ReqLinKvCaS {
393 msg_id,
394 key,
395 from,
396 to,
397 create_if_not_exists,
398 })
399 .await;
400 match res {
401 Body::Error { code, text, .. } => Err(MaelstromError { code, text }),
402 Body::ResLinKvCaS { .. } => Ok(()),
403 res => unimplemented!("unsupported res: {:?}", res),
404 }
405 }
406}
407
/// A value that is set once via `init_once` and can be awaited via `get`
/// before it has been set.
#[derive(Debug)]
struct AsyncInitOnceWaitable<T: Clone> {
    /// `.0` holds the value once it has been set; `.1` holds the senders for
    /// `get` callers that arrived before the value was set.
    core: tokio::sync::Mutex<(Option<T>, Vec<oneshot::Sender<T>>)>,
}
416
417impl<T: Clone> AsyncInitOnceWaitable<T> {
418 pub fn new() -> Self {
419 let core = (None, Vec::new());
420 AsyncInitOnceWaitable {
421 core: tokio::sync::Mutex::new(core),
422 }
423 }
424
425 pub async fn init_once(&self, t: T) {
426 let mut core = self.core.lock().await;
427 assert!(core.0.is_none(), "init called more than once");
428 core.0 = Some(t.clone());
429 for tx in core.1.drain(..) {
430 let _ = tx.send(t.clone());
431 }
432 }
433
434 pub async fn get(&self) -> T {
435 let rx = {
436 let mut core = self.core.lock().await;
437 if let Some(x) = core.0.as_ref() {
438 return x.clone();
439 }
440 let (tx, rx) = tokio::sync::oneshot::channel();
441 core.1.push(tx);
442 rx
443 };
444 rx.await.expect("internal error: waiter oneshot dropped")
445 }
446}