Skip to main content

mz_service/
transport.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The Cluster Transport Protocol (CTP).
11//!
12//! CTP is the protocol used to transmit commands from controllers to replicas and responses from
13//! replicas to controllers. It runs on top of a reliable bidirectional connection stream, as
14//! provided by TCP or UDS, and adds message framing as well as heartbeating.
15//!
16//! CTP supports any message type that implements the serde [`Serialize`] and [`Deserialize`]
17//! traits. Messages are encoded using the [`bincode`] format and then sent over the wire with a
18//! length prefix.
19//!
20//! A CTP server only serves a single client at a time. If a new client connects while a connection
21//! is already established, the previous connection is canceled.
22
23mod metrics;
24
25use std::convert::Infallible;
26use std::fmt::Debug;
27use std::time::Duration;
28
29use anyhow::bail;
30use async_trait::async_trait;
31use bincode::Options;
32use futures::future;
33use mz_ore::cast::CastInto;
34use mz_ore::netio::{Listener, SocketAddr, Stream, TimedReader, TimedWriter};
35use mz_ore::task::{AbortOnDropHandle, JoinHandle};
36use semver::Version;
37use serde::de::DeserializeOwned;
38use serde::{Deserialize, Serialize};
39use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
40use tokio::sync::{mpsc, oneshot, watch};
41use tracing::{Instrument, debug, info, trace, warn};
42
43use crate::client::{GenericClient, Partitionable, Partitioned};
44
45pub use metrics::{Metrics, NoopMetrics};
46
47/// Trait for messages that can be sent over CTP.
48pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static {}
49impl<T: Debug + Send + Sync + Serialize + DeserializeOwned + 'static> Message for T {}
50
51/// A client for a CTP connection.
52#[derive(Debug)]
53pub struct Client<Out, In> {
54    conn: Connection<Out, In>,
55}
56
57impl<Out: Message, In: Message> Client<Out, In> {
58    /// Connect to the server at the given address.
59    ///
60    /// This call resolves once a connection with the server host was either established, was
61    /// rejected, or timed out.
62    pub async fn connect(
63        address: &str,
64        version: Version,
65        connect_timeout: Duration,
66        idle_timeout: Duration,
67        metrics: impl Metrics<Out, In>,
68    ) -> anyhow::Result<Self> {
69        let dest_host = host_from_address(address);
70        let stream = mz_ore::future::timeout(connect_timeout, Stream::connect(address)).await?;
71        info!(%address, "ctp: connected to server");
72
73        let conn = Connection::start(stream, version, dest_host, idle_timeout, metrics).await?;
74        Ok(Self { conn })
75    }
76}
77
/// Extract the host part of `address`.
///
/// Two shapes are understood: `<host>:<port>` and `<protocol>:<host>:<port>`. In both cases
/// `<port>` must parse as a `u16`. Anything else yields `None`. That includes addresses with
/// more colons, such as bracketed IPv6 literals.
fn host_from_address(address: &str) -> Option<String> {
    let segments: Vec<&str> = address.split(':').collect();
    let (host, port) = match segments.as_slice() {
        [host, port] | [_, host, port] => (*host, *port),
        _ => return None,
    };

    // The port is only validated; its value is not needed.
    port.parse::<u16>().ok().map(|_| host.to_owned())
}
93
94impl<Out, In> Client<Out, In>
95where
96    Out: Message,
97    In: Message,
98    (Out, In): Partitionable<Out, In>,
99{
100    /// Create a `Partitioned` client that connects through CTP.
101    pub async fn connect_partitioned(
102        addresses: Vec<String>,
103        version: Version,
104        connect_timeout: Duration,
105        idle_timeout: Duration,
106        metrics: impl Metrics<Out, In>,
107    ) -> anyhow::Result<Partitioned<Self, Out, In>> {
108        let connects = addresses.iter().map(|addr| {
109            Self::connect(
110                addr,
111                version.clone(),
112                connect_timeout,
113                idle_timeout,
114                metrics.clone(),
115            )
116        });
117        let clients = future::try_join_all(connects).await?;
118        Ok(Partitioned::new(clients))
119    }
120}
121
122#[async_trait]
123impl<Out: Message, In: Message> GenericClient<Out, In> for Client<Out, In> {
124    async fn send(&mut self, cmd: Out) -> anyhow::Result<()> {
125        self.conn.send(cmd).await
126    }
127
128    /// # Cancel safety
129    ///
130    /// This method is cancel safe.
131    async fn recv(&mut self) -> anyhow::Result<Option<In>> {
132        // `Connection::recv` is documented to be cancel safe.
133        self.conn.recv().await.map(Some)
134    }
135}
136
137/// Spawn a CTP server that serves connections at the given address.
138pub async fn serve<In, Out, H>(
139    address: SocketAddr,
140    version: Version,
141    server_fqdn: Option<String>,
142    idle_timeout: Duration,
143    handler_fn: impl Fn() -> H,
144    metrics: impl Metrics<Out, In>,
145) -> anyhow::Result<()>
146where
147    In: Message,
148    Out: Message,
149    H: GenericClient<In, Out> + 'static,
150{
151    // Keep a handle to the task serving the current connection, as well as a cancelation token, so
152    // we can cancel it when a new client connects.
153    //
154    // Note that we cannot simply abort the previous connection task because its future isn't known
155    // to be cancel safe. Instead we pass the connection tasks a cancelation token and wait for
156    // them to shut themselves down gracefully once the token gets dropped.
157    let mut connection_task: Option<(JoinHandle<()>, oneshot::Sender<()>)> = None;
158
159    let listener = Listener::bind(&address).await?;
160    info!(%address, "ctp: listening for client connections");
161
162    loop {
163        let (stream, peer) = listener.accept().await?;
164        info!(%peer, "ctp: accepted client connection");
165
166        // Cancel any existing connection before starting to serve the new one.
167        if let Some((task, token)) = connection_task.take() {
168            drop(token);
169            task.await;
170        }
171
172        let handler = handler_fn();
173        let version = version.clone();
174        let server_fqdn = server_fqdn.clone();
175        let metrics = metrics.clone();
176        let (cancel_tx, cancel_rx) = oneshot::channel();
177
178        let span = tracing::Span::current();
179        let handle = mz_ore::task::spawn(
180            || "ctp::connection",
181            async move {
182                let Err(error) = serve_connection(
183                    stream,
184                    handler,
185                    version,
186                    server_fqdn,
187                    idle_timeout,
188                    cancel_rx,
189                    metrics,
190                )
191                .await;
192                info!("ctp: connection failed: {error}");
193            }
194            .instrument(span),
195        );
196
197        connection_task = Some((handle, cancel_tx));
198    }
199}
200
201/// Serve a single CTP connection.
202async fn serve_connection<In, Out, H>(
203    stream: Stream,
204    mut handler: H,
205    version: Version,
206    server_fqdn: Option<String>,
207    timeout: Duration,
208    cancel_rx: oneshot::Receiver<()>,
209    metrics: impl Metrics<Out, In>,
210) -> anyhow::Result<Infallible>
211where
212    In: Message,
213    Out: Message,
214    H: GenericClient<In, Out>,
215{
216    let mut conn = Connection::start(stream, version, server_fqdn, timeout, metrics).await?;
217
218    let mut cancel_rx = cancel_rx;
219    loop {
220        tokio::select! {
221            // `Connection::recv` is documented to be cancel safe.
222            inbound = conn.recv() => {
223                let msg = inbound?;
224                handler.send(msg).await?;
225            },
226            // `GenericClient::recv` is documented to be cancel safe.
227            outbound = handler.recv() => match outbound? {
228                Some(msg) => conn.send(msg).await?,
229                None => bail!("client disconnected"),
230            },
231            _ = &mut cancel_rx => bail!("connection canceled"),
232        }
233    }
234}
235
236/// An active CTP connection.
237///
238/// This type encapsulates the core connection logic. It is used by both the client and the server
239/// implementation, with swapped `Out`/`In` types.
240///
241/// Each connection spawns two tasks:
242///
243///  * The send task is responsible for encoding and sending enqueued messages.
244///  * The recv task is responsible for receiving and decoding messages from the peer.
245///
246/// The separation into tasks provides some performance isolation between the sending and the
247/// receiving half of the connection.
248#[derive(Debug)]
249struct Connection<Out, In> {
250    /// Message sender connected to the send task.
251    msg_tx: mpsc::UnboundedSender<Out>,
252    /// Message receiver connected to the receive task.
253    msg_rx: mpsc::UnboundedReceiver<In>,
254    /// Receiver for errors encountered by connection tasks.
255    error_rx: watch::Receiver<String>,
256
257    /// Handles to connection tasks.
258    _tasks: [AbortOnDropHandle<()>; 2],
259}
260
261impl<Out: Message, In: Message> Connection<Out, In> {
262    /// The interval with which keepalives are emitted on idle connections.
263    const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
264    /// The minimum acceptable idle timeout.
265    ///
266    /// We want this to be significantly greater than `KEEPALIVE_INTERVAL`, to avoid connections
267    /// getting canceled unnecessarily.
268    const MIN_TIMEOUT: Duration = Duration::from_secs(2);
269
270    /// Start a new connection wrapping the given stream.
271    async fn start(
272        stream: Stream,
273        version: Version,
274        server_fqdn: Option<String>,
275        mut timeout: Duration,
276        metrics: impl Metrics<Out, In>,
277    ) -> anyhow::Result<Self> {
278        if timeout < Self::MIN_TIMEOUT {
279            warn!(
280                ?timeout,
281                "ctp: configured timeout is less than minimum timeout",
282            );
283            timeout = Self::MIN_TIMEOUT;
284        }
285
286        let (reader, writer) = stream.split();
287
288        // Apply the timeout to all connection reads and writes.
289        let reader = TimedReader::new(reader, timeout);
290        let writer = TimedWriter::new(writer, timeout);
291        // Track byte count metrics for all connection reads and writes.
292        let mut reader = metrics::Reader::new(reader, metrics.clone());
293        let mut writer = metrics::Writer::new(writer, metrics.clone());
294
295        handshake(&mut reader, &mut writer, version, server_fqdn).await?;
296
297        let (out_tx, out_rx) = mpsc::unbounded_channel();
298        let (in_tx, in_rx) = mpsc::unbounded_channel();
299        // Initialize the error channel with a default error to return if none of the tasks
300        // produced an error.
301        let (error_tx, error_rx) = watch::channel("connection closed".into());
302
303        let span = tracing::Span::current();
304        let send_task = mz_ore::task::spawn(
305            || "ctp::send",
306            Self::run_send_task(writer, out_rx, error_tx.clone(), metrics.clone())
307                .instrument(span.clone()),
308        );
309        let recv_task = mz_ore::task::spawn(
310            || "ctp::recv",
311            Self::run_recv_task(reader, in_tx, error_tx, metrics).instrument(span),
312        );
313
314        Ok(Self {
315            msg_tx: out_tx,
316            msg_rx: in_rx,
317            error_rx,
318            _tasks: [send_task.abort_on_drop(), recv_task.abort_on_drop()],
319        })
320    }
321
322    /// Enqueue a message for sending.
323    async fn send(&mut self, msg: Out) -> anyhow::Result<()> {
324        match self.msg_tx.send(msg) {
325            Ok(()) => Ok(()),
326            Err(_) => bail!(self.collect_error().await),
327        }
328    }
329
330    /// Return a received message.
331    ///
332    /// # Cancel safety
333    ///
334    /// This method is cancel safe.
335    async fn recv(&mut self) -> anyhow::Result<In> {
336        // `mpcs::Receiver::recv` is documented to be cancel safe.
337        match self.msg_rx.recv().await {
338            Some(msg) => Ok(msg),
339            None => bail!(self.collect_error().await),
340        }
341    }
342
343    /// Return a connection error.
344    async fn collect_error(&mut self) -> String {
345        // Wait for the first error to be reported, or for all connection tasks to shut down.
346        let _ = self.error_rx.changed().await;
347        // Mark the current value as unseen, so the next `collect_error` call can return
348        // immediately.
349        self.error_rx.mark_changed();
350
351        self.error_rx.borrow().clone()
352    }
353
354    /// Run a connection's send task.
355    async fn run_send_task<W: AsyncWrite + Unpin>(
356        mut writer: W,
357        mut msg_rx: mpsc::UnboundedReceiver<Out>,
358        error_tx: watch::Sender<String>,
359        mut metrics: impl Metrics<Out, In>,
360    ) {
361        loop {
362            let msg = tokio::select! {
363                // `mpsc::UnboundedReceiver::recv` is cancel safe.
364                msg = msg_rx.recv() => match msg {
365                    Some(msg) => {
366                        trace!(?msg, "ctp: sending message");
367                        Some(msg)
368                    }
369                    None => break,
370                },
371                // `tokio::time::sleep` is cancel safe.
372                _ = tokio::time::sleep(Self::KEEPALIVE_INTERVAL) => {
373                    trace!("ctp: sending keepalive");
374                    None
375                },
376            };
377
378            if let Err(error) = write_message(&mut writer, msg.as_ref()).await {
379                debug!("ctp: send error: {error}");
380                let _ = error_tx.send(error.to_string());
381                break;
382            };
383
384            if let Some(msg) = &msg {
385                metrics.message_sent(msg);
386            }
387        }
388    }
389
390    /// Run a connection's recv task.
391    async fn run_recv_task<R: AsyncRead + Unpin>(
392        mut reader: R,
393        msg_tx: mpsc::UnboundedSender<In>,
394        error_tx: watch::Sender<String>,
395        mut metrics: impl Metrics<Out, In>,
396    ) {
397        loop {
398            match read_message(&mut reader).await {
399                Ok(msg) => {
400                    trace!(?msg, "ctp: received message");
401                    metrics.message_received(&msg);
402
403                    if msg_tx.send(msg).is_err() {
404                        break;
405                    }
406                }
407                Err(error) => {
408                    debug!("ctp: recv error: {error}");
409                    let _ = error_tx.send(error.to_string());
410                    break;
411                }
412            };
413        }
414    }
415}
416
417/// Perform the CTP handshake.
418///
419/// To perform the handshake, each endpoint sends the protocol magic number, followed by a
420/// `Hello` message. The `Hello` message contains information about the originating endpoint that
421/// is used by the receiver to validate compatibility with its peer. Only if both endpoints
422/// determine that they are compatible does the handshake succeed.
423async fn handshake<R, W>(
424    mut reader: R,
425    mut writer: W,
426    version: Version,
427    server_fqdn: Option<String>,
428) -> anyhow::Result<()>
429where
430    R: AsyncRead + Unpin,
431    W: AsyncWrite + Unpin,
432{
433    /// A randomly chosen magic number identifying CTP connections.
434    const MAGIC: u64 = 0x477574656e546167;
435
436    writer.write_u64(MAGIC).await?;
437
438    let hello = Hello {
439        version: version.clone(),
440        server_fqdn: server_fqdn.clone(),
441    };
442    write_message(&mut writer, Some(&hello)).await?;
443
444    let peer_magic = reader.read_u64().await?;
445    if peer_magic != MAGIC {
446        bail!("invalid protocol magic: {peer_magic:#x}");
447    }
448
449    let Hello {
450        version: peer_version,
451        server_fqdn: peer_server_fqdn,
452    } = read_message(&mut reader).await?;
453
454    if peer_version != version {
455        bail!("version mismatch: {peer_version} != {version}");
456    }
457    if let (Some(other), Some(mine)) = (&peer_server_fqdn, &server_fqdn) {
458        if other != mine {
459            bail!("server FQDN mismatch: {other} != {mine}");
460        }
461    }
462
463    Ok(())
464}
465
466/// A message for exchanging compatibility information during the CTP handshake.
467#[derive(Debug, Serialize, Deserialize)]
468struct Hello {
469    /// The version of the originating endpoint.
470    version: Version,
471    /// The FQDN of the server endpoint.
472    server_fqdn: Option<String>,
473}
474
475/// Write a message into the given writer.
476///
477/// The message can be `None`, in which case an empty message is written. This is used to implement
478/// keepalives. At the receiver, empty messages are ignored, but they do reset the read timeout.
479async fn write_message<W, M>(mut writer: W, msg: Option<&M>) -> anyhow::Result<()>
480where
481    W: AsyncWrite + Unpin,
482    M: Message,
483{
484    let bytes = match msg {
485        Some(msg) => &*wire_encode(msg)?,
486        None => &[],
487    };
488
489    let len = bytes.len().cast_into();
490    writer.write_u64(len).await?;
491    writer.write_all(bytes).await?;
492
493    Ok(())
494}
495
496/// Read a message from the given reader.
497async fn read_message<R, M>(mut reader: R) -> anyhow::Result<M>
498where
499    R: AsyncRead + Unpin,
500    M: Message,
501{
502    // Skip over any empty messages (i.e. keepalives).
503    let mut len = 0;
504    while len == 0 {
505        len = reader.read_u64().await?;
506    }
507
508    let mut bytes = vec![0; len.cast_into()];
509    reader.read_exact(&mut bytes).await?;
510
511    wire_decode(&bytes)
512}
513
514/// Encode a message for wire transport.
515fn wire_encode<M: Message>(msg: &M) -> anyhow::Result<Vec<u8>> {
516    let bytes = bincode::DefaultOptions::new().serialize(msg)?;
517    Ok(bytes)
518}
519
520/// Decode a wire frame back into a message.
521fn wire_decode<M: Message>(bytes: &[u8]) -> anyhow::Result<M> {
522    let msg = bincode::DefaultOptions::new().deserialize(bytes)?;
523    Ok(msg)
524}