mz_service/
client.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Traits for client–server communication independent of transport layer.
//!
//! These traits are designed for servers where commands must be sharded
//! among several worker threads or processes.
14
15use std::fmt;
16use std::pin::Pin;
17
18use async_trait::async_trait;
19use futures::stream::{Stream, StreamExt};
20use itertools::Itertools;
21use tokio_stream::StreamMap;
22use tracing::trace;
23
24/// A generic client to a server that receives commands and asynchronously
25/// produces responses.
26#[async_trait]
27pub trait GenericClient<C, R>: fmt::Debug + Send {
28    /// Sends a command to the dataflow server.
29    ///
30    /// The command can error for various reasons.
31    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error>;
32
33    /// Receives the next response from the dataflow server.
34    ///
35    /// This method blocks until the next response is available.
36    ///
37    /// A return value of `Ok(Some(_))` transmits a response.
38    ///
39    /// A return value of `Ok(None)` indicates graceful termination of the
40    /// connection. The owner of the client should not call `recv` again.
41    ///
42    /// A return value of `Err(_)` indicates an unrecoverable error. After
43    /// observing an error, the owner of the client must drop the client.
44    ///
45    /// Implementations of this method **must** be [cancellation safe]. That
46    /// means that work must not be lost if the future returned by this method
47    /// is dropped.
48    ///
49    /// [cancellation safe]: https://docs.rs/tokio/latest/tokio/macro.select.html#cancellation-safety
50    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error>;
51
52    /// Returns an adapter that treats the client as a stream.
53    ///
54    /// The stream produces the responses that would be produced by repeated
55    /// calls to `recv`.
56    ///
57    /// # Cancel safety
58    ///
59    /// The returned stream is cancel safe. If `stream.next()` is used as the event in a
60    /// [`tokio::select!`] statement and some other branch completes first, it is guaranteed that
61    /// no messages were received by this client.
62    fn as_stream<'a>(
63        &'a mut self,
64    ) -> Pin<Box<dyn Stream<Item = Result<R, anyhow::Error>> + Send + 'a>>
65    where
66        R: Send + 'a,
67    {
68        Box::pin(async_stream::stream!({
69            loop {
70                // `GenericClient::recv` is required to be cancel safe.
71                match self.recv().await {
72                    Ok(Some(response)) => yield Ok(response),
73                    Err(error) => yield Err(error),
74                    Ok(None) => {
75                        return;
76                    }
77                }
78            }
79        }))
80    }
81}
82
83#[async_trait]
84impl<C, R> GenericClient<C, R> for Box<dyn GenericClient<C, R>>
85where
86    C: Send,
87{
88    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
89        (**self).send(cmd).await
90    }
91
92    /// # Cancel safety
93    ///
94    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
95    /// statement and some other branch completes first, it is guaranteed that no messages were
96    /// received by this client.
97    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
98        // `GenericClient::recv` is required to be cancel safe.
99        (**self).recv().await
100    }
101}
102
103/// A client whose implementation is partitioned across a number of other
104/// clients.
105///
106/// Such a client needs to broadcast (partitioned) commands to all of its
107/// clients, and await responses from each of the client partitions before it
108/// can respond.
109#[derive(Debug)]
110pub struct Partitioned<P, C, R>
111where
112    (C, R): Partitionable<C, R>,
113{
114    /// The individual partitions representing per-worker clients.
115    pub parts: Vec<P>,
116    /// The partitioned state.
117    state: <(C, R) as Partitionable<C, R>>::PartitionedState,
118}
119
120impl<P, C, R> Partitioned<P, C, R>
121where
122    (C, R): Partitionable<C, R>,
123{
124    /// Create a client partitioned across multiple client shards.
125    pub fn new(parts: Vec<P>) -> Self {
126        Self {
127            state: <(C, R) as Partitionable<C, R>>::new(parts.len()),
128            parts,
129        }
130    }
131}
132
133#[async_trait]
134impl<P, C, R> GenericClient<C, R> for Partitioned<P, C, R>
135where
136    P: GenericClient<C, R>,
137    (C, R): Partitionable<C, R>,
138    C: fmt::Debug + Send,
139    R: fmt::Debug + Send,
140{
141    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
142        trace!(command = ?cmd, "splitting command");
143        let cmd_parts = self.state.split_command(cmd);
144        for (index, (shard, cmd_part)) in self.parts.iter_mut().zip_eq(cmd_parts).enumerate() {
145            if let Some(cmd) = cmd_part {
146                trace!(shard = ?index, command = ?cmd, "sending command");
147                shard.send(cmd).await?;
148            }
149        }
150        Ok(())
151    }
152
153    /// # Cancel safety
154    ///
155    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
156    /// statement and some other branch completes first, it is guaranteed that no messages were
157    /// received by this client.
158    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
159        let mut stream: StreamMap<_, _> = self
160            .parts
161            .iter_mut()
162            .map(|shard| shard.as_stream())
163            .enumerate()
164            .collect();
165
166        // `stream` is a cancel safe stream: It only awaits streams created with
167        // `GenericClient::as_stream`, which is documented to produce cancel safe streams.
168        // Thus no messages are lost if `stream` is dropped while awaiting its next element.
169        while let Some((index, response)) = stream.next().await {
170            match response {
171                Err(e) => {
172                    return Err(e);
173                }
174                Ok(response) => {
175                    trace!(shard = ?index, response = ?response, "received response");
176                    if let Some(response) = self.state.absorb_response(index, response) {
177                        trace!(response = ?response, "returning response");
178                        return response.map(Some);
179                    }
180                }
181            }
182        }
183        // Indicate completion of the communication.
184        Ok(None)
185    }
186}
187
188/// A trait for command–response pairs that can be partitioned across multiple
189/// workers via [`Partitioned`].
190pub trait Partitionable<C, R> {
191    /// The type which functions as the state machine for the partitioning.
192    type PartitionedState: PartitionedState<C, R>;
193
194    /// Construct a [`PartitionedState`] for the command–response pair.
195    fn new(parts: usize) -> Self::PartitionedState;
196}
197
198/// A state machine for a partitioned client that partitions commands across and
199/// amalgamates responses from multiple partitions.
200pub trait PartitionedState<C, R>: fmt::Debug + Send {
201    /// Splits a command into multiple partitions.
202    fn split_command(&mut self, command: C) -> Vec<Option<C>>;
203
204    /// Absorbs a response from a single partition.
205    ///
206    /// If responses from all partitions have been absorbed, returns an
207    /// amalgamated response.
208    fn absorb_response(&mut self, shard_id: usize, response: R)
209    -> Option<Result<R, anyhow::Error>>;
210}