mz_service/
transport.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! The Cluster Transport Protocol (CTP).
//!
//! CTP is the protocol used to transmit commands from controllers to replicas and responses from
//! replicas to controllers. It runs on top of a reliable bidirectional connection stream, as
//! provided by TCP or UDS, and adds message framing as well as heartbeating.
//!
//! CTP supports any message type that implements the serde [`Serialize`] and [`Deserialize`]
//! traits. Messages are encoded using the [`bincode`] format and then sent over the wire with a
//! length prefix.
//!
//! A CTP server only serves a single client at a time. If a new client connects while a connection
//! is already established, the previous connection is canceled.
22
23mod metrics;
24
25use std::convert::Infallible;
26use std::fmt::Debug;
27use std::time::Duration;
28
29use anyhow::{anyhow, bail};
30use async_trait::async_trait;
31use bincode::Options;
32use futures::future;
33use mz_ore::cast::CastInto;
34use mz_ore::netio::{Listener, SocketAddr, Stream, TimedReader, TimedWriter};
35use mz_ore::task::{AbortOnDropHandle, JoinHandle, JoinHandleExt};
36use semver::Version;
37use serde::de::DeserializeOwned;
38use serde::{Deserialize, Serialize};
39use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
40use tokio::sync::{mpsc, oneshot, watch};
41use tracing::{debug, info, trace, warn};
42
43use crate::client::{GenericClient, Partitionable, Partitioned};
44
45pub use metrics::{Metrics, NoopMetrics};
46
47/// Trait for messages that can be sent over CTP.
48pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static {}
49impl<T: Debug + Send + Sync + Serialize + DeserializeOwned + 'static> Message for T {}
50
51/// A client for a CTP connection.
52#[derive(Debug)]
53pub struct Client<Out, In> {
54    conn: Connection<Out, In>,
55}
56
57impl<Out: Message, In: Message> Client<Out, In> {
58    /// Connect to the server at the given address.
59    ///
60    /// This call resolves once a connection with the server host was either established, was
61    /// rejected, or timed out.
62    pub async fn connect(
63        address: &str,
64        version: Version,
65        connect_timeout: Duration,
66        idle_timeout: Duration,
67        metrics: impl Metrics<Out, In>,
68    ) -> anyhow::Result<Self> {
69        let dest_host = host_from_address(address);
70        let stream = mz_ore::future::timeout(connect_timeout, Stream::connect(address)).await?;
71        info!(%address, "ctp: connected to server");
72
73        let conn = Connection::start(stream, version, dest_host, idle_timeout, metrics).await?;
74        Ok(Self { conn })
75    }
76}
77
/// Extract the host part from an address string.
///
/// Addresses are expected to have the form `<host>:<port>` or `<protocol>:<host>:<port>`, where
/// `<port>` parses as a `u16`. Any other shape yields `None`. This includes bracketed IPv6
/// addresses such as `[::1]:6876`, which contain additional colons.
fn host_from_address(address: &str) -> Option<String> {
    let parts: Vec<&str> = address.split(':').collect();
    let (host, port) = match parts.as_slice() {
        [host, port] | [_, host, port] => (*host, *port),
        _ => return None,
    };

    // The port is only validated, never returned.
    port.parse::<u16>().ok().map(|_| host.to_owned())
}
93
94impl<Out, In> Client<Out, In>
95where
96    Out: Message,
97    In: Message,
98    (Out, In): Partitionable<Out, In>,
99{
100    /// Create a `Partitioned` client that connects through CTP.
101    pub async fn connect_partitioned(
102        addresses: Vec<String>,
103        version: Version,
104        connect_timeout: Duration,
105        idle_timeout: Duration,
106        metrics: impl Metrics<Out, In>,
107    ) -> anyhow::Result<Partitioned<Self, Out, In>> {
108        let connects = addresses.iter().map(|addr| {
109            Self::connect(
110                addr,
111                version.clone(),
112                connect_timeout,
113                idle_timeout,
114                metrics.clone(),
115            )
116        });
117        let clients = future::try_join_all(connects).await?;
118        Ok(Partitioned::new(clients))
119    }
120}
121
122#[async_trait]
123impl<Out: Message, In: Message> GenericClient<Out, In> for Client<Out, In> {
124    async fn send(&mut self, cmd: Out) -> anyhow::Result<()> {
125        self.conn.send(cmd).await
126    }
127
128    /// # Cancel safety
129    ///
130    /// This method is cancel safe.
131    async fn recv(&mut self) -> anyhow::Result<Option<In>> {
132        // `Connection::recv` is documented to be cancel safe.
133        self.conn.recv().await.map(Some)
134    }
135}
136
137/// Spawn a CTP server that serves connections at the given address.
138pub async fn serve<In, Out, H>(
139    address: SocketAddr,
140    version: Version,
141    server_fqdn: Option<String>,
142    idle_timeout: Duration,
143    handler_fn: impl Fn() -> H,
144    metrics: impl Metrics<Out, In>,
145) -> anyhow::Result<()>
146where
147    In: Message,
148    Out: Message,
149    H: GenericClient<In, Out> + 'static,
150{
151    // Keep a handle to the task serving the current connection, as well as a cancelation token, so
152    // we can cancel it when a new client connects.
153    //
154    // Note that we cannot simply abort the previous connection task because its future isn't known
155    // to be cancel safe. Instead we pass the connection tasks a cancelation token and wait for
156    // them to shut themselves down gracefully once the token gets dropped.
157    let mut connection_task: Option<(JoinHandle<()>, oneshot::Sender<()>)> = None;
158
159    let listener = Listener::bind(&address).await?;
160    info!(%address, "ctp: listening for client connections");
161
162    loop {
163        let (stream, peer) = listener.accept().await?;
164        info!(%peer, "ctp: accepted client connection");
165
166        // Cancel any existing connection before starting to serve the new one.
167        if let Some((task, token)) = connection_task.take() {
168            drop(token);
169            task.wait_and_assert_finished().await;
170        }
171
172        let handler = handler_fn();
173        let version = version.clone();
174        let server_fqdn = server_fqdn.clone();
175        let metrics = metrics.clone();
176        let (cancel_tx, cancel_rx) = oneshot::channel();
177
178        let handle = mz_ore::task::spawn(|| "ctp::connection", async move {
179            let Err(error) = serve_connection(
180                stream,
181                handler,
182                version,
183                server_fqdn,
184                idle_timeout,
185                cancel_rx,
186                metrics,
187            )
188            .await;
189            info!("ctp: connection failed: {error}");
190        });
191
192        connection_task = Some((handle, cancel_tx));
193    }
194}
195
196/// Serve a single CTP connection.
197async fn serve_connection<In, Out, H>(
198    stream: Stream,
199    mut handler: H,
200    version: Version,
201    server_fqdn: Option<String>,
202    timeout: Duration,
203    cancel_rx: oneshot::Receiver<()>,
204    metrics: impl Metrics<Out, In>,
205) -> anyhow::Result<Infallible>
206where
207    In: Message,
208    Out: Message,
209    H: GenericClient<In, Out>,
210{
211    let mut conn = Connection::start(stream, version, server_fqdn, timeout, metrics).await?;
212
213    let mut cancel_rx = cancel_rx;
214    loop {
215        tokio::select! {
216            // `Connection::recv` is documented to be cancel safe.
217            inbound = conn.recv() => {
218                let msg = inbound?;
219                handler.send(msg).await?;
220            },
221            // `GenericClient::recv` is documented to be cancel safe.
222            outbound = handler.recv() => match outbound? {
223                Some(msg) => conn.send(msg).await?,
224                None => bail!("client disconnected"),
225            },
226            _ = &mut cancel_rx => bail!("connection canceled"),
227        }
228    }
229}
230
231/// An active CTP connection.
232///
233/// This type encapsulates the core connection logic. It is used by both the client and the server
234/// implementation, with swapped `Out`/`In` types.
235///
236/// Each connection spawns two tasks:
237///
238///  * The send task is responsible for encoding and sending enqueued messages.
239///  * The recv task is responsible for receiving and decoding messages from the peer.
240///
241/// The separation into tasks provides some performance isolation between the sending and the
242/// receiving half of the connection.
243#[derive(Debug)]
244struct Connection<Out, In> {
245    /// Message sender connected to the send task.
246    msg_tx: mpsc::Sender<Out>,
247    /// Message receiver connected to the receive task.
248    msg_rx: mpsc::Receiver<In>,
249    /// Receiver for errors encountered by connection tasks.
250    error_rx: watch::Receiver<String>,
251
252    /// Handles to connection tasks.
253    _tasks: [AbortOnDropHandle<()>; 2],
254}
255
256impl<Out: Message, In: Message> Connection<Out, In> {
257    /// The interval with which keepalives are emitted on idle connections.
258    const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
259    /// The minimum acceptable idle timeout.
260    ///
261    /// We want this to be significantly greater than `KEEPALIVE_INTERVAL`, to avoid connections
262    /// getting canceled unnecessarily.
263    const MIN_TIMEOUT: Duration = Duration::from_secs(2);
264
265    /// Start a new connection wrapping the given stream.
266    async fn start(
267        stream: Stream,
268        version: Version,
269        server_fqdn: Option<String>,
270        mut timeout: Duration,
271        metrics: impl Metrics<Out, In>,
272    ) -> anyhow::Result<Self> {
273        if timeout < Self::MIN_TIMEOUT {
274            warn!(
275                ?timeout,
276                "ctp: configured timeout is less than minimum timeout",
277            );
278            timeout = Self::MIN_TIMEOUT;
279        }
280
281        let (reader, writer) = stream.split();
282
283        // Apply the timeout to all connection reads and writes.
284        let reader = TimedReader::new(reader, timeout);
285        let writer = TimedWriter::new(writer, timeout);
286        // Track byte count metrics for all connection reads and writes.
287        let mut reader = metrics::Reader::new(reader, metrics.clone());
288        let mut writer = metrics::Writer::new(writer, metrics.clone());
289
290        handshake(&mut reader, &mut writer, version, server_fqdn).await?;
291
292        let (out_tx, out_rx) = mpsc::channel(1024);
293        let (in_tx, in_rx) = mpsc::channel(1024);
294        // Initialize the error channel with a default error to return if none of the tasks
295        // produced an error.
296        let (error_tx, error_rx) = watch::channel("connection closed".into());
297
298        let send_task = mz_ore::task::spawn(
299            || "ctp::send",
300            Self::run_send_task(writer, out_rx, error_tx.clone(), metrics.clone()),
301        );
302        let recv_task = mz_ore::task::spawn(
303            || "ctp::recv",
304            Self::run_recv_task(reader, in_tx, error_tx, metrics),
305        );
306
307        Ok(Self {
308            msg_tx: out_tx,
309            msg_rx: in_rx,
310            error_rx,
311            _tasks: [send_task.abort_on_drop(), recv_task.abort_on_drop()],
312        })
313    }
314
315    /// Enqueue a message for sending.
316    async fn send(&mut self, msg: Out) -> anyhow::Result<()> {
317        match self.msg_tx.send(msg).await {
318            Ok(()) => Ok(()),
319            Err(_) => bail!(self.collect_error().await),
320        }
321    }
322
323    /// Return a received message.
324    ///
325    /// # Cancel safety
326    ///
327    /// This method is cancel safe.
328    async fn recv(&mut self) -> anyhow::Result<In> {
329        // `mpcs::Receiver::recv` is documented to be cancel safe.
330        match self.msg_rx.recv().await {
331            Some(msg) => Ok(msg),
332            None => bail!(self.collect_error().await),
333        }
334    }
335
336    /// Return a connection error.
337    async fn collect_error(&mut self) -> String {
338        // Wait for the first error to be reported, or for all connection tasks to shut down.
339        let _ = self.error_rx.changed().await;
340        // Mark the current value as unseen, so the next `collect_error` call can return
341        // immediately.
342        self.error_rx.mark_changed();
343
344        self.error_rx.borrow().clone()
345    }
346
347    /// Run a connection's send task.
348    async fn run_send_task<W: AsyncWrite + Unpin>(
349        mut writer: W,
350        mut msg_rx: mpsc::Receiver<Out>,
351        error_tx: watch::Sender<String>,
352        mut metrics: impl Metrics<Out, In>,
353    ) {
354        loop {
355            let msg = tokio::select! {
356                // `mpsc::UnboundedReceiver::recv` is cancel safe.
357                msg = msg_rx.recv() => match msg {
358                    Some(msg) => {
359                        trace!(?msg, "ctp: sending message");
360                        Some(msg)
361                    }
362                    None => break,
363                },
364                // `tokio::time::sleep` is cancel safe.
365                _ = tokio::time::sleep(Self::KEEPALIVE_INTERVAL) => {
366                    trace!("ctp: sending keepalive");
367                    None
368                },
369            };
370
371            if let Err(error) = write_message(&mut writer, msg.as_ref()).await {
372                debug!("ctp: send error: {error}");
373                let _ = error_tx.send(error.to_string());
374                break;
375            };
376
377            if let Some(msg) = &msg {
378                metrics.message_sent(msg);
379            }
380        }
381    }
382
383    /// Run a connection's recv task.
384    async fn run_recv_task<R: AsyncRead + Unpin>(
385        mut reader: R,
386        msg_tx: mpsc::Sender<In>,
387        error_tx: watch::Sender<String>,
388        mut metrics: impl Metrics<Out, In>,
389    ) {
390        loop {
391            match read_message(&mut reader).await {
392                Ok(msg) => {
393                    trace!(?msg, "ctp: received message");
394                    metrics.message_received(&msg);
395
396                    if msg_tx.send(msg).await.is_err() {
397                        break;
398                    }
399                }
400                Err(error) => {
401                    debug!("ctp: recv error: {error}");
402                    let _ = error_tx.send(error.to_string());
403                    break;
404                }
405            };
406        }
407    }
408}
409
410/// A connection handler that simply forwards messages over channels.
411#[derive(Debug)]
412pub struct ChannelHandler<In, Out> {
413    tx: mpsc::UnboundedSender<In>,
414    rx: mpsc::UnboundedReceiver<Out>,
415}
416
417impl<In, Out> ChannelHandler<In, Out> {
418    pub fn new(tx: mpsc::UnboundedSender<In>, rx: mpsc::UnboundedReceiver<Out>) -> Self {
419        Self { tx, rx }
420    }
421}
422
423#[async_trait]
424impl<In: Message, Out: Message> GenericClient<In, Out> for ChannelHandler<In, Out> {
425    async fn send(&mut self, cmd: In) -> anyhow::Result<()> {
426        let result = self.tx.send(cmd);
427        result.map_err(|_| anyhow!("client channel disconnected"))
428    }
429
430    /// # Cancel safety
431    ///
432    /// This method is cancel safe.
433    async fn recv(&mut self) -> anyhow::Result<Option<Out>> {
434        // `mpsc::UnboundedReceiver::recv` is cancel safe.
435        match self.rx.recv().await {
436            Some(resp) => Ok(Some(resp)),
437            None => bail!("client channel disconnected"),
438        }
439    }
440}
441
442/// Perform the CTP handshake.
443///
444/// To perform the handshake, each endpoint sends the protocol magic number, followed by a
445/// `Hello` message. The `Hello` message contains information about the originating endpoint that
446/// is used by the receiver to validate compatibility with its peer. Only if both endpoints
447/// determine that they are compatible does the handshake succeed.
448async fn handshake<R, W>(
449    mut reader: R,
450    mut writer: W,
451    version: Version,
452    server_fqdn: Option<String>,
453) -> anyhow::Result<()>
454where
455    R: AsyncRead + Unpin,
456    W: AsyncWrite + Unpin,
457{
458    /// A randomly chosen magic number identifying CTP connections.
459    const MAGIC: u64 = 0x477574656e546167;
460
461    writer.write_u64(MAGIC).await?;
462
463    let hello = Hello {
464        version: version.clone(),
465        server_fqdn: server_fqdn.clone(),
466    };
467    write_message(&mut writer, Some(&hello)).await?;
468
469    let peer_magic = reader.read_u64().await?;
470    if peer_magic != MAGIC {
471        bail!("invalid protocol magic: {peer_magic:#x}");
472    }
473
474    let Hello {
475        version: peer_version,
476        server_fqdn: peer_server_fqdn,
477    } = read_message(&mut reader).await?;
478
479    if peer_version != version {
480        bail!("version mismatch: {peer_version} != {version}");
481    }
482    if let (Some(other), Some(mine)) = (&peer_server_fqdn, &server_fqdn) {
483        if other != mine {
484            bail!("server FQDN mismatch: {other} != {mine}");
485        }
486    }
487
488    Ok(())
489}
490
491/// A message for exchanging compatibility information during the CTP handshake.
492#[derive(Debug, Serialize, Deserialize)]
493struct Hello {
494    /// The version of the originating endpoint.
495    version: Version,
496    /// The FQDN of the server endpoint.
497    server_fqdn: Option<String>,
498}
499
500/// Write a message into the given writer.
501///
502/// The message can be `None`, in which case an empty message is written. This is used to implement
503/// keepalives. At the receiver, empty messages are ignored, but they do reset the read timeout.
504async fn write_message<W, M>(mut writer: W, msg: Option<&M>) -> anyhow::Result<()>
505where
506    W: AsyncWrite + Unpin,
507    M: Message,
508{
509    let bytes = match msg {
510        Some(msg) => &*wire_encode(msg)?,
511        None => &[],
512    };
513
514    let len = bytes.len().cast_into();
515    writer.write_u64(len).await?;
516    writer.write_all(bytes).await?;
517
518    Ok(())
519}
520
521/// Read a message from the given reader.
522async fn read_message<R, M>(mut reader: R) -> anyhow::Result<M>
523where
524    R: AsyncRead + Unpin,
525    M: Message,
526{
527    // Skip over any empty messages (i.e. keepalives).
528    let mut len = 0;
529    while len == 0 {
530        len = reader.read_u64().await?;
531    }
532
533    let mut bytes = vec![0; len.cast_into()];
534    reader.read_exact(&mut bytes).await?;
535
536    wire_decode(&bytes)
537}
538
539/// Encode a message for wire transport.
540fn wire_encode<M: Message>(msg: &M) -> anyhow::Result<Vec<u8>> {
541    let bytes = bincode::DefaultOptions::new().serialize(msg)?;
542    Ok(bytes)
543}
544
545/// Decode a wire frame back into a message.
546fn wire_decode<M: Message>(bytes: &[u8]) -> anyhow::Result<M> {
547    let msg = bincode::DefaultOptions::new().deserialize(bytes)?;
548    Ok(msg)
549}