persistcli/maelstrom/
node.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A driver for interacting with Maelstrom
11//!
12//! This translates input into requests and requests into output. It also
13//! handles issuing Maelstrom [service] requests. It's very roughly based off of
14//! the node in Maelstrom's [ruby examples].
15//!
16//! [service]: https://github.com/jepsen-io/maelstrom/blob/v0.2.1/doc/services.md
17//! [ruby examples]: https://github.com/jepsen-io/maelstrom/blob/v0.2.1/demo/ruby/node.rb
18
19use std::collections::BTreeMap;
20use std::io::{BufRead, Write};
21use std::sync::{Arc, Mutex};
22
23use anyhow::anyhow;
24use async_trait::async_trait;
25use mz_persist::location::ExternalError;
26use mz_persist_client::ShardId;
27use serde_json::Value;
28use tokio::sync::oneshot;
29use tracing::{Instrument, debug_span, info, trace};
30
31use crate::maelstrom::Args;
32use crate::maelstrom::api::{Body, ErrorCode, MaelstromError, Msg, MsgId, NodeId};
33
/// An implementor of a Maelstrom [workload].
///
/// The node constructs the service once via [Service::init] and then calls
/// [Service::eval] for every incoming message that is not a reply to one of
/// the node's own service requests.
///
/// [workload]: https://github.com/jepsen-io/maelstrom/blob/v0.2.1/doc/workload.md
#[async_trait]
pub trait Service: Sized + Send + Sync {
    /// Construct this service.
    ///
    /// Maelstrom services are available via the [Handle]. The node calls this
    /// from a spawned task after it has answered Maelstrom's init request, and
    /// panics if an error is returned, so any retries belong inside the
    /// implementation.
    async fn init(args: &Args, handle: &Handle) -> Result<Self, MaelstromError>;

    /// Respond to a single request.
    ///
    /// `src` is the sender of the request and `req` is its body. The node runs
    /// each call in its own spawned task.
    ///
    /// Implementations must either panic or respond by calling
    /// [Handle::send_res] exactly once.
    async fn eval(&self, handle: Handle, src: NodeId, req: Body);
}
50
51/// Runs the RPC loop, accepting Maelstrom workload requests, issuing responses,
52/// and communicating with Maelstrom services.
53pub fn run<R, W, S>(args: Args, read: R, write: W) -> Result<(), anyhow::Error>
54where
55    R: BufRead,
56    W: Write + Send + Sync + 'static,
57    S: Service + 'static,
58{
59    let mut node = Node::<S>::new(args, write);
60    for line in read.lines() {
61        let line = line.map_err(|err| anyhow!("req read failed: {}", err))?;
62        trace!("raw: [{}]", line);
63        let req: Msg = line
64            .parse()
65            .map_err(|err| anyhow!("invalid req {}: {}", err, line))?;
66        if req.body.in_reply_to().is_none() {
67            info!("req: {}", req);
68        } else {
69            trace!("req: {}", req);
70        }
71        node.handle(req);
72    }
73    Ok(())
74}
75
/// The per-process Maelstrom node: owns the shared output/callback state and
/// routes each incoming message either to a pending service request or to the
/// workload [Service].
struct Node<S>
where
    S: Service + 'static,
{
    // Passed to `S::init` when the init request arrives.
    args: Args,
    // Shared with every `Handle` created by this node.
    core: Arc<Mutex<Core>>,
    // `None` until a `Body::ReqInit` has been processed.
    node_id: Option<NodeId>,
    // Filled in by the spawned init task; request tasks wait on it.
    service: Arc<AsyncInitOnceWaitable<Arc<S>>>,
}
85
86impl<S> Node<S>
87where
88    S: Service + 'static,
89{
90    fn new<W>(args: Args, write: W) -> Self
91    where
92        W: Write + Send + Sync + 'static,
93    {
94        let core = Core {
95            write: Box::new(write),
96            next_msg_id: MsgId(0),
97            callbacks: BTreeMap::new(),
98        };
99        Node {
100            args,
101            core: Arc::new(Mutex::new(core)),
102            node_id: None,
103            service: Arc::new(AsyncInitOnceWaitable::new()),
104        }
105    }
106
107    pub fn handle(&mut self, msg: Msg) {
108        // If we've been initialized (i.e. have a NodeId), respond to the
109        // message.
110        if let Some(node_id) = self.node_id.as_ref() {
111            // This message is not for us
112            if node_id != &msg.dest {
113                return;
114            }
115
116            let handle = Handle {
117                node_id: node_id.clone(),
118                core: Arc::clone(&self.core),
119            };
120
121            let body = match handle.maybe_handle_service_res(&msg.src, msg.body) {
122                Ok(()) => return,
123                Err(x) => x,
124            };
125
126            let service = Arc::clone(&self.service);
127            mz_ore::task::spawn(
128                || "maelstrom::handle".to_string(),
129                async move {
130                    let service = service.get().await;
131                    let () = service.eval(handle, msg.src, body).await;
132                }
133                .instrument(debug_span!("maelstrom::handle")),
134            );
135            return;
136        }
137
138        // Otherwise, if we haven't yet been initialized, then the only message
139        // type we are allowed to process is ReqInit.
140        match msg.body {
141            Body::ReqInit {
142                msg_id, node_id, ..
143            } => {
144                // Set the NodeId.
145                self.node_id = Some(node_id.clone());
146                let handle = Handle {
147                    node_id,
148                    core: Arc::clone(&self.core),
149                };
150
151                // Respond to the init req.
152                //
153                // NB: This must come _before_ service init! We want service
154                // init to be able to use Maelstrom services, but Maelstrom
155                // doesn't make services available to nodes that haven't yet
156                // responded to init.
157                let in_reply_to = msg_id;
158                handle.send_res(msg.src, move |msg_id| Body::ResInit {
159                    msg_id,
160                    in_reply_to,
161                });
162
163                // Tricky! Run the service init in a task in case it uses
164                // Maelstrom services. This is because Maelstrom services return
165                // responses on stdin, which means we need to be processing the
166                // run loop concurrently with this. This is also the reason for
167                // the AsyncInitOnceWaitable nonsense.
168                let args = self.args.clone();
169                let service_init = Arc::clone(&self.service);
170                mz_ore::task::spawn(
171                    || "maelstrom::init".to_string(),
172                    async move {
173                        let service = match S::init(&args, &handle).await {
174                            Ok(x) => x,
175                            Err(err) => {
176                                // If service initialization fails, there's nothing
177                                // to do but panic. Any retries should be pushed
178                                // into the impl of `init`.
179                                panic!("service initialization failed: {}", err);
180                            }
181                        };
182                        service_init.init_once(Arc::new(service)).await;
183                    }
184                    .instrument(debug_span!("maelstrom::init")),
185                );
186            }
187            // All other reqs are a no-op. We can't even error without a NodeId.
188            _ => {}
189        }
190    }
191}
192
/// State shared (behind a mutex) between the node and all of its handles.
struct Core {
    // Sink that every outgoing message is written to, one per line.
    write: Box<dyn Write + Send + Sync>,
    // Most recently allocated message id; see `alloc_msg_id`.
    next_msg_id: MsgId,
    // Reply channels for in-flight service requests, keyed by the request's
    // message id.
    callbacks: BTreeMap<MsgId, oneshot::Sender<Body>>,
}
198
199impl std::fmt::Debug for Core {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        // Destructure the struct to be defensive against new fields.
202        let Core {
203            write: _,
204            next_msg_id,
205            callbacks,
206        } = self;
207        f.debug_struct("Core")
208            .field("next_msg_id", &next_msg_id)
209            .field("callbacks", &callbacks.keys().collect::<Vec<_>>())
210            .finish_non_exhaustive()
211    }
212}
213
214impl Core {
215    fn alloc_msg_id(&mut self) -> MsgId {
216        self.next_msg_id = self.next_msg_id.next();
217        self.next_msg_id
218    }
219}
220
/// A handle to interact with Node.
///
/// Cloning is cheap: clones share the same underlying `Core`.
#[derive(Debug, Clone)]
pub struct Handle {
    // Id of the node this handle speaks for; used as `src` of outgoing msgs.
    node_id: NodeId,
    // Output sink, id counter and pending callbacks, shared with the node.
    core: Arc<Mutex<Core>>,
}
227
228impl Handle {
229    /// Returns this handle's NodeId.
230    pub fn node_id(&self) -> NodeId {
231        self.node_id.clone()
232    }
233
234    /// Send a response to Maelstrom.
235    ///
236    /// `dest` should be the `src` of the response. To make a service request,
237    /// use [Self::send_service_req] instead.
238    pub fn send_res<BodyFn: FnOnce(MsgId) -> Body>(&self, dest: NodeId, res_fn: BodyFn) {
239        let mut core = self.core.lock().expect("mutex poisoned");
240        let msg_id = core.alloc_msg_id();
241        let res = Msg {
242            src: self.node_id.clone(),
243            dest,
244            body: res_fn(msg_id),
245        };
246        info!("res: {}", res);
247        write!(core.write.as_mut(), "{}\n", res).expect("res write failed");
248    }
249
250    /// Issue a service request to Maelstrom.
251    ///
252    /// `dest` should be the service name. To respond to a request, use
253    /// [Self::send_res] instead.
254    pub async fn send_service_req<BodyFn: FnOnce(MsgId) -> Body>(
255        &self,
256        dest: NodeId,
257        req_fn: BodyFn,
258    ) -> Body {
259        let (tx, rx) = oneshot::channel();
260        {
261            let mut core = self.core.lock().expect("mutex poisoned");
262            let msg_id = core.alloc_msg_id();
263            core.callbacks.insert(msg_id, tx);
264            let req = Msg {
265                src: self.node_id.clone(),
266                dest,
267                body: req_fn(msg_id),
268            };
269            trace!("svc: {}", req);
270            write!(core.write.as_mut(), "{}\n", req).expect("req write failed");
271        }
272        rx.await.expect("internal error: callback oneshot dropped")
273    }
274
275    /// Attempts to handle a msg as a service response, returning it back if it
276    /// isn't one.
277    #[allow(clippy::result_large_err)]
278    pub fn maybe_handle_service_res(&self, src: &NodeId, msg: Body) -> Result<(), Body> {
279        let in_reply_to = match msg.in_reply_to() {
280            Some(x) => x,
281            None => return Err(msg),
282        };
283
284        let mut core = self.core.lock().expect("mutex poisoned");
285        let callback = match core.callbacks.remove(&in_reply_to) {
286            Some(x) => x,
287            None => {
288                self.send_res(src.clone(), |msg_id| Body::Error {
289                    msg_id: Some(msg_id),
290                    in_reply_to,
291                    code: ErrorCode::MalformedRequest,
292                    text: format!("no callback expected for {:?}", in_reply_to),
293                });
294                return Ok(());
295            }
296        };
297
298        if let Err(_) = callback.send(msg) {
299            // The caller is no longer listening. This is safe to ignore.
300            return Ok(());
301        }
302
303        Ok(())
304    }
305
306    /// Returns a [ShardId] for this Maelstrom run.
307    ///
308    /// Uses Maelstrom services to ensure all nodes end up with the same id.
309    pub async fn maybe_init_shard_id(&self) -> Result<ShardId, MaelstromError> {
310        let proposal = ShardId::new();
311        let key = "SHARD";
312        loop {
313            let from = Value::Null;
314            let to = Value::from(proposal.to_string());
315            match self
316                .lin_kv_compare_and_set(Value::from(key), from, to, Some(true))
317                .await
318            {
319                Ok(()) => {
320                    info!("initialized maelstrom shard to {}", proposal);
321                    return Ok(proposal);
322                }
323                Err(MaelstromError {
324                    code: ErrorCode::PreconditionFailed,
325                    ..
326                }) => match self.lin_kv_read(Value::from(key)).await? {
327                    Some(value) => {
328                        let value = value.as_str().ok_or_else(|| {
329                            ExternalError::from(anyhow!("invalid SHARD {}", value))
330                        })?;
331                        let shard_id = value.parse::<ShardId>().map_err(|err| {
332                            ExternalError::from(anyhow!("invalid SHARD {}: {}", value, err))
333                        })?;
334                        info!("fetched maelstrom shard id {}", shard_id);
335                        return Ok(shard_id);
336                    }
337                    None => continue,
338                },
339                Err(err) => return Err(err),
340            }
341        }
342    }
343
344    /// Issues a Maelstrom lin-kv service read request.
345    pub async fn lin_kv_read(&self, key: Value) -> Result<Option<Value>, MaelstromError> {
346        let dest = NodeId("lin-kv".to_string());
347        let res = self
348            .send_service_req(dest, move |msg_id| Body::ReqLinKvRead { msg_id, key })
349            .await;
350        match res {
351            Body::Error {
352                code: ErrorCode::KeyDoesNotExist,
353                ..
354            } => Ok(None),
355            Body::Error { code, text, .. } => Err(MaelstromError { code, text }),
356            Body::ResLinKvRead { value, .. } => Ok(Some(value)),
357            res => unimplemented!("unsupported res: {:?}", res),
358        }
359    }
360
361    /// Issues a Maelstrom lin-kv service write request.
362    pub async fn lin_kv_write(&self, key: Value, value: Value) -> Result<(), MaelstromError> {
363        let dest = NodeId("lin-kv".to_string());
364        let res = self
365            .send_service_req(dest, move |msg_id| Body::ReqLinKvWrite {
366                msg_id,
367                key,
368                value,
369            })
370            .await;
371        match res {
372            Body::Error { code, text, .. } => Err(MaelstromError { code, text }),
373            Body::ResLinKvWrite { .. } => Ok(()),
374            res => unimplemented!("unsupported res: {:?}", res),
375        }
376    }
377
378    /// Issues a Maelstrom lin-kv service cas request.
379    pub async fn lin_kv_compare_and_set(
380        &self,
381        key: Value,
382        from: Value,
383        to: Value,
384        create_if_not_exists: Option<bool>,
385    ) -> Result<(), MaelstromError> {
386        trace!(
387            "lin_kv_compare_and_set key={:?} from={:?} to={:?} create_if_not_exists={:?}",
388            key, from, to, create_if_not_exists
389        );
390        let dest = NodeId("lin-kv".to_string());
391        let res = self
392            .send_service_req(dest, move |msg_id| Body::ReqLinKvCaS {
393                msg_id,
394                key,
395                from,
396                to,
397                create_if_not_exists,
398            })
399            .await;
400        match res {
401            Body::Error { code, text, .. } => Err(MaelstromError { code, text }),
402            Body::ResLinKvCaS { .. } => Ok(()),
403            res => unimplemented!("unsupported res: {:?}", res),
404        }
405    }
406}
407
/// A helper for a value that is initialized once, but used from many async
/// places.
///
/// Callers of `get` that arrive before `init_once` are parked and each
/// receives a clone of the value once it is set; later callers get a clone
/// immediately.
///
/// This name sure is a mouthful. Anyone have a suggestion?
#[derive(Debug)]
struct AsyncInitOnceWaitable<T: Clone> {
    // `.0` is the value once it has been set; `.1` holds one sender per caller
    // of `get` that is still waiting for it.
    core: tokio::sync::Mutex<(Option<T>, Vec<oneshot::Sender<T>>)>,
}
416
417impl<T: Clone> AsyncInitOnceWaitable<T> {
418    pub fn new() -> Self {
419        let core = (None, Vec::new());
420        AsyncInitOnceWaitable {
421            core: tokio::sync::Mutex::new(core),
422        }
423    }
424
425    pub async fn init_once(&self, t: T) {
426        let mut core = self.core.lock().await;
427        assert!(core.0.is_none(), "init called more than once");
428        core.0 = Some(t.clone());
429        for tx in core.1.drain(..) {
430            let _ = tx.send(t.clone());
431        }
432    }
433
434    pub async fn get(&self) -> T {
435        let rx = {
436            let mut core = self.core.lock().await;
437            if let Some(x) = core.0.as_ref() {
438                return x.clone();
439            }
440            let (tx, rx) = tokio::sync::oneshot::channel();
441            core.1.push(tx);
442            rx
443        };
444        rx.await.expect("internal error: waiter oneshot dropped")
445    }
446}