mz_service/
client.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Traits for client–server communication independent of transport layer.
//!
//! These traits are designed for servers where commands must be sharded
//! among several worker threads or processes.
14
15use std::fmt;
16use std::pin::Pin;
17
18use async_trait::async_trait;
19use futures::stream::{Stream, StreamExt};
20use tokio_stream::StreamMap;
21use tracing::trace;
22
23/// A generic client to a server that receives commands and asynchronously
24/// produces responses.
25#[async_trait]
26pub trait GenericClient<C, R>: fmt::Debug + Send {
27    /// Sends a command to the dataflow server.
28    ///
29    /// The command can error for various reasons.
30    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error>;
31
32    /// Receives the next response from the dataflow server.
33    ///
34    /// This method blocks until the next response is available.
35    ///
36    /// A return value of `Ok(Some(_))` transmits a response.
37    ///
38    /// A return value of `Ok(None)` indicates graceful termination of the
39    /// connection. The owner of the client should not call `recv` again.
40    ///
41    /// A return value of `Err(_)` indicates an unrecoverable error. After
42    /// observing an error, the owner of the client must drop the client.
43    ///
44    /// Implementations of this method **must** be [cancellation safe]. That
45    /// means that work must not be lost if the future returned by this method
46    /// is dropped.
47    ///
48    /// [cancellation safe]: https://docs.rs/tokio/latest/tokio/macro.select.html#cancellation-safety
49    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error>;
50
51    /// Returns an adapter that treats the client as a stream.
52    ///
53    /// The stream produces the responses that would be produced by repeated
54    /// calls to `recv`.
55    ///
56    /// # Cancel safety
57    ///
58    /// The returned stream is cancel safe. If `stream.next()` is used as the event in a
59    /// [`tokio::select!`] statement and some other branch completes first, it is guaranteed that
60    /// no messages were received by this client.
61    fn as_stream<'a>(
62        &'a mut self,
63    ) -> Pin<Box<dyn Stream<Item = Result<R, anyhow::Error>> + Send + 'a>>
64    where
65        R: Send + 'a,
66    {
67        Box::pin(async_stream::stream!({
68            loop {
69                // `GenericClient::recv` is required to be cancel safe.
70                match self.recv().await {
71                    Ok(Some(response)) => yield Ok(response),
72                    Err(error) => yield Err(error),
73                    Ok(None) => {
74                        return;
75                    }
76                }
77            }
78        }))
79    }
80}
81
82#[async_trait]
83impl<C, R> GenericClient<C, R> for Box<dyn GenericClient<C, R>>
84where
85    C: Send,
86{
87    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
88        (**self).send(cmd).await
89    }
90
91    /// # Cancel safety
92    ///
93    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
94    /// statement and some other branch completes first, it is guaranteed that no messages were
95    /// received by this client.
96    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
97        // `GenericClient::recv` is required to be cancel safe.
98        (**self).recv().await
99    }
100}
101
102/// A client whose implementation is partitioned across a number of other
103/// clients.
104///
105/// Such a client needs to broadcast (partitioned) commands to all of its
106/// clients, and await responses from each of the client partitions before it
107/// can respond.
108#[derive(Debug)]
109pub struct Partitioned<P, C, R>
110where
111    (C, R): Partitionable<C, R>,
112{
113    /// The individual partitions representing per-worker clients.
114    pub parts: Vec<P>,
115    /// The partitioned state.
116    state: <(C, R) as Partitionable<C, R>>::PartitionedState,
117}
118
119impl<P, C, R> Partitioned<P, C, R>
120where
121    (C, R): Partitionable<C, R>,
122{
123    /// Create a client partitioned across multiple client shards.
124    pub fn new(parts: Vec<P>) -> Self {
125        Self {
126            state: <(C, R) as Partitionable<C, R>>::new(parts.len()),
127            parts,
128        }
129    }
130}
131
132#[async_trait]
133impl<P, C, R> GenericClient<C, R> for Partitioned<P, C, R>
134where
135    P: GenericClient<C, R>,
136    (C, R): Partitionable<C, R>,
137    C: fmt::Debug + Send,
138    R: fmt::Debug + Send,
139{
140    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
141        trace!(command = ?cmd, "splitting command");
142        let cmd_parts = self.state.split_command(cmd);
143        for (index, (shard, cmd_part)) in self.parts.iter_mut().zip(cmd_parts).enumerate() {
144            if let Some(cmd) = cmd_part {
145                trace!(shard = ?index, command = ?cmd, "sending command");
146                shard.send(cmd).await?;
147            }
148        }
149        Ok(())
150    }
151
152    /// # Cancel safety
153    ///
154    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
155    /// statement and some other branch completes first, it is guaranteed that no messages were
156    /// received by this client.
157    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
158        let mut stream: StreamMap<_, _> = self
159            .parts
160            .iter_mut()
161            .map(|shard| shard.as_stream())
162            .enumerate()
163            .collect();
164
165        // `stream` is a cancel safe stream: It only awaits streams created with
166        // `GenericClient::as_stream`, which is documented to produce cancel safe streams.
167        // Thus no messages are lost if `stream` is dropped while awaiting its next element.
168        while let Some((index, response)) = stream.next().await {
169            match response {
170                Err(e) => {
171                    return Err(e);
172                }
173                Ok(response) => {
174                    trace!(shard = ?index, response = ?response, "received response");
175                    if let Some(response) = self.state.absorb_response(index, response) {
176                        trace!(response = ?response, "returning response");
177                        return response.map(Some);
178                    }
179                }
180            }
181        }
182        // Indicate completion of the communication.
183        Ok(None)
184    }
185}
186
187/// A trait for command–response pairs that can be partitioned across multiple
188/// workers via [`Partitioned`].
189pub trait Partitionable<C, R> {
190    /// The type which functions as the state machine for the partitioning.
191    type PartitionedState: PartitionedState<C, R>;
192
193    /// Construct a [`PartitionedState`] for the command–response pair.
194    fn new(parts: usize) -> Self::PartitionedState;
195}
196
197/// A state machine for a partitioned client that partitions commands across and
198/// amalgamates responses from multiple partitions.
199pub trait PartitionedState<C, R>: fmt::Debug + Send {
200    /// Splits a command into multiple partitions.
201    fn split_command(&mut self, command: C) -> Vec<Option<C>>;
202
203    /// Absorbs a response from a single partition.
204    ///
205    /// If responses from all partitions have been absorbed, returns an
206    /// amalgamated response.
207    fn absorb_response(&mut self, shard_id: usize, response: R)
208    -> Option<Result<R, anyhow::Error>>;
209}