1mod metrics;
24
25use std::convert::Infallible;
26use std::fmt::Debug;
27use std::time::Duration;
28
29use anyhow::bail;
30use async_trait::async_trait;
31use bincode::Options;
32use futures::future;
33use mz_ore::cast::CastInto;
34use mz_ore::netio::{Listener, SocketAddr, Stream, TimedReader, TimedWriter};
35use mz_ore::task::{AbortOnDropHandle, JoinHandle};
36use semver::Version;
37use serde::de::DeserializeOwned;
38use serde::{Deserialize, Serialize};
39use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
40use tokio::sync::{mpsc, oneshot, watch};
41use tracing::{Instrument, debug, info, trace, warn};
42
43use crate::client::{GenericClient, Partitionable, Partitioned};
44
45pub use metrics::{Metrics, NoopMetrics};
46
47pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static {}
49impl<T: Debug + Send + Sync + Serialize + DeserializeOwned + 'static> Message for T {}
50
/// Client half of a single CTP connection.
///
/// Sends `Out` messages to, and receives `In` messages from, the CTP peer it
/// was connected to (see `Client::connect`).
#[derive(Debug)]
pub struct Client<Out, In> {
    /// The underlying framed connection (handshake already completed).
    conn: Connection<Out, In>,
}
56
57impl<Out: Message, In: Message> Client<Out, In> {
58 pub async fn connect(
63 address: &str,
64 version: Version,
65 connect_timeout: Duration,
66 idle_timeout: Duration,
67 metrics: impl Metrics<Out, In>,
68 ) -> anyhow::Result<Self> {
69 let dest_host = host_from_address(address);
70 let stream = mz_ore::future::timeout(connect_timeout, Stream::connect(address)).await?;
71 info!(%address, "ctp: connected to server");
72
73 let conn = Connection::start(stream, version, dest_host, idle_timeout, metrics).await?;
74 Ok(Self { conn })
75 }
76}
77
/// Extracts the host component from a connection address.
///
/// Accepted shapes are `<host>:<port>` and `<protocol>:<host>:<port>`, where
/// `<port>` must parse as a `u16`. Any other shape yields `None`.
fn host_from_address(address: &str) -> Option<String> {
    let fields: Vec<&str> = address.split(':').collect();

    let (host, port) = match fields.as_slice() {
        [host, port] | [_, host, port] => (*host, *port),
        _ => return None,
    };

    // The port is only validated, never returned.
    port.parse::<u16>().ok().map(|_| host.to_owned())
}
93
94impl<Out, In> Client<Out, In>
95where
96 Out: Message,
97 In: Message,
98 (Out, In): Partitionable<Out, In>,
99{
100 pub async fn connect_partitioned(
102 addresses: Vec<String>,
103 version: Version,
104 connect_timeout: Duration,
105 idle_timeout: Duration,
106 metrics: impl Metrics<Out, In>,
107 ) -> anyhow::Result<Partitioned<Self, Out, In>> {
108 let connects = addresses.iter().map(|addr| {
109 Self::connect(
110 addr,
111 version.clone(),
112 connect_timeout,
113 idle_timeout,
114 metrics.clone(),
115 )
116 });
117 let clients = future::try_join_all(connects).await?;
118 Ok(Partitioned::new(clients))
119 }
120}
121
122#[async_trait]
123impl<Out: Message, In: Message> GenericClient<Out, In> for Client<Out, In> {
124 async fn send(&mut self, cmd: Out) -> anyhow::Result<()> {
125 self.conn.send(cmd).await
126 }
127
128 async fn recv(&mut self) -> anyhow::Result<Option<In>> {
132 self.conn.recv().await.map(Some)
134 }
135}
136
137pub async fn serve<In, Out, H>(
139 address: SocketAddr,
140 version: Version,
141 server_fqdn: Option<String>,
142 idle_timeout: Duration,
143 handler_fn: impl Fn() -> H,
144 metrics: impl Metrics<Out, In>,
145) -> anyhow::Result<()>
146where
147 In: Message,
148 Out: Message,
149 H: GenericClient<In, Out> + 'static,
150{
151 let mut connection_task: Option<(JoinHandle<()>, oneshot::Sender<()>)> = None;
158
159 let listener = Listener::bind(&address).await?;
160 info!(%address, "ctp: listening for client connections");
161
162 loop {
163 let (stream, peer) = listener.accept().await?;
164 info!(%peer, "ctp: accepted client connection");
165
166 if let Some((task, token)) = connection_task.take() {
168 drop(token);
169 task.await;
170 }
171
172 let handler = handler_fn();
173 let version = version.clone();
174 let server_fqdn = server_fqdn.clone();
175 let metrics = metrics.clone();
176 let (cancel_tx, cancel_rx) = oneshot::channel();
177
178 let span = tracing::Span::current();
179 let handle = mz_ore::task::spawn(
180 || "ctp::connection",
181 async move {
182 let Err(error) = serve_connection(
183 stream,
184 handler,
185 version,
186 server_fqdn,
187 idle_timeout,
188 cancel_rx,
189 metrics,
190 )
191 .await;
192 info!("ctp: connection failed: {error}");
193 }
194 .instrument(span),
195 );
196
197 connection_task = Some((handle, cancel_tx));
198 }
199}
200
201async fn serve_connection<In, Out, H>(
203 stream: Stream,
204 mut handler: H,
205 version: Version,
206 server_fqdn: Option<String>,
207 timeout: Duration,
208 cancel_rx: oneshot::Receiver<()>,
209 metrics: impl Metrics<Out, In>,
210) -> anyhow::Result<Infallible>
211where
212 In: Message,
213 Out: Message,
214 H: GenericClient<In, Out>,
215{
216 let mut conn = Connection::start(stream, version, server_fqdn, timeout, metrics).await?;
217
218 let mut cancel_rx = cancel_rx;
219 loop {
220 tokio::select! {
221 inbound = conn.recv() => {
223 let msg = inbound?;
224 handler.send(msg).await?;
225 },
226 outbound = handler.recv() => match outbound? {
228 Some(msg) => conn.send(msg).await?,
229 None => bail!("client disconnected"),
230 },
231 _ = &mut cancel_rx => bail!("connection canceled"),
232 }
233 }
234}
235
/// A framed, bidirectional CTP connection driven by two background tasks.
///
/// Outgoing messages are queued on `msg_tx` and written to the stream by a
/// send task. Incoming frames are read and decoded by a receive task and
/// delivered on `msg_rx`. When either task hits a read/write/decode error, it
/// publishes the error text on the watch channel observed through `error_rx`.
#[derive(Debug)]
struct Connection<Out, In> {
    /// Queue of outgoing messages consumed by the send task.
    msg_tx: mpsc::UnboundedSender<Out>,
    /// Messages decoded by the receive task.
    msg_rx: mpsc::UnboundedReceiver<In>,
    /// Most recent error text published by either task.
    error_rx: watch::Receiver<String>,

    /// Handles for the send and receive tasks; held only so they are dropped
    /// with the connection (presumably aborting both tasks, per the
    /// `abort_on_drop` conversion in `start` — confirm in `mz_ore::task`).
    _tasks: [AbortOnDropHandle<()>; 2],
}
260
impl<Out: Message, In: Message> Connection<Out, In> {
    /// Idle time after which the send task writes a keepalive (empty) frame.
    const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
    /// Lower bound enforced on the configured timeout in `start`.
    ///
    /// NOTE(review): presumably kept above `KEEPALIVE_INTERVAL` so an idle but
    /// healthy peer's keepalives arrive before the read timeout fires — confirm.
    const MIN_TIMEOUT: Duration = Duration::from_secs(2);

    /// Performs the handshake over `stream` and spawns the send/receive tasks.
    ///
    /// `timeout` is raised to [`Self::MIN_TIMEOUT`] (with a warning) if smaller.
    /// It is then applied to both stream halves via `TimedReader` and
    /// `TimedWriter`. Both halves are also wrapped in the metrics reader/writer.
    ///
    /// # Errors
    ///
    /// Returns the handshake error if the handshake fails.
    async fn start(
        stream: Stream,
        version: Version,
        server_fqdn: Option<String>,
        mut timeout: Duration,
        metrics: impl Metrics<Out, In>,
    ) -> anyhow::Result<Self> {
        if timeout < Self::MIN_TIMEOUT {
            warn!(
                ?timeout,
                "ctp: configured timeout is less than minimum timeout",
            );
            timeout = Self::MIN_TIMEOUT;
        }

        let (reader, writer) = stream.split();

        // NOTE(review): assumes TimedReader/TimedWriter fail an operation that makes
        // no progress within `timeout` — confirm in `mz_ore::netio`.
        let reader = TimedReader::new(reader, timeout);
        let writer = TimedWriter::new(writer, timeout);
        let mut reader = metrics::Reader::new(reader, metrics.clone());
        let mut writer = metrics::Writer::new(writer, metrics.clone());

        handshake(&mut reader, &mut writer, version, server_fqdn).await?;

        let (out_tx, out_rx) = mpsc::unbounded_channel();
        let (in_tx, in_rx) = mpsc::unbounded_channel();
        // The initial value is what `collect_error` reports if both tasks exit
        // without publishing an error.
        let (error_tx, error_rx) = watch::channel("connection closed".into());

        let span = tracing::Span::current();
        let send_task = mz_ore::task::spawn(
            || "ctp::send",
            Self::run_send_task(writer, out_rx, error_tx.clone(), metrics.clone())
                .instrument(span.clone()),
        );
        let recv_task = mz_ore::task::spawn(
            || "ctp::recv",
            Self::run_recv_task(reader, in_tx, error_tx, metrics).instrument(span),
        );

        Ok(Self {
            msg_tx: out_tx,
            msg_rx: in_rx,
            error_rx,
            _tasks: [send_task.abort_on_drop(), recv_task.abort_on_drop()],
        })
    }

    /// Queues `msg` for the send task.
    ///
    /// `Ok` means only that the message was enqueued, not that it reached the
    /// stream. Once the send task has stopped, this fails with the error text
    /// returned by `collect_error`.
    async fn send(&mut self, msg: Out) -> anyhow::Result<()> {
        match self.msg_tx.send(msg) {
            Ok(()) => Ok(()),
            Err(_) => bail!(self.collect_error().await),
        }
    }

    /// Waits for the next message decoded by the receive task.
    ///
    /// Once the receive task has stopped and the queue is drained, fails with
    /// the error text returned by `collect_error`. The wait itself is
    /// `UnboundedReceiver::recv`, which tokio documents as cancel-safe. That
    /// matters because `serve_connection` calls this inside `tokio::select!`.
    async fn recv(&mut self) -> anyhow::Result<In> {
        match self.msg_rx.recv().await {
            Some(msg) => Ok(msg),
            None => bail!(self.collect_error().await),
        }
    }

    /// Returns the error text published by the send or receive task.
    ///
    /// Waits until a value has been published on the watch channel, or until all
    /// senders are gone. In the latter case the current value (initially
    /// "connection closed") is returned.
    async fn collect_error(&mut self) -> String {
        // An `Err` here only means every `error_tx` was dropped; the current value
        // is still the best description available, so the result is ignored.
        let _ = self.error_rx.changed().await;
        // Mark the value as unseen again so a later call (e.g. a `send` after a
        // failed `recv`) returns immediately instead of waiting for another change.
        self.error_rx.mark_changed();

        self.error_rx.borrow().clone()
    }

    /// Body of the send task: writes queued messages, or keepalive frames when idle.
    ///
    /// Exits when the outgoing queue is closed or a write fails. On a write
    /// failure, the error text is published on `error_tx` first.
    async fn run_send_task<W: AsyncWrite + Unpin>(
        mut writer: W,
        mut msg_rx: mpsc::UnboundedReceiver<Out>,
        error_tx: watch::Sender<String>,
        mut metrics: impl Metrics<Out, In>,
    ) {
        loop {
            // The sleep is recreated on every iteration, so a keepalive (`None`) is
            // written only after a full `KEEPALIVE_INTERVAL` without outgoing messages.
            let msg = tokio::select! {
                msg = msg_rx.recv() => match msg {
                    Some(msg) => {
                        trace!(?msg, "ctp: sending message");
                        Some(msg)
                    }
                    None => break,
                },
                _ = tokio::time::sleep(Self::KEEPALIVE_INTERVAL) => {
                    trace!("ctp: sending keepalive");
                    None
                },
            };

            if let Err(error) = write_message(&mut writer, msg.as_ref()).await {
                debug!("ctp: send error: {error}");
                let _ = error_tx.send(error.to_string());
                break;
            };

            // Only real messages are counted, and only after a successful write.
            if let Some(msg) = &msg {
                metrics.message_sent(msg);
            }
        }
    }

    /// Body of the receive task: reads and decodes frames and forwards the messages.
    ///
    /// Exits when a read or decode fails (publishing the error text on
    /// `error_tx`), or when forwarding fails because the receiving half of the
    /// queue was dropped.
    async fn run_recv_task<R: AsyncRead + Unpin>(
        mut reader: R,
        msg_tx: mpsc::UnboundedSender<In>,
        error_tx: watch::Sender<String>,
        mut metrics: impl Metrics<Out, In>,
    ) {
        loop {
            match read_message(&mut reader).await {
                Ok(msg) => {
                    trace!(?msg, "ctp: received message");
                    metrics.message_received(&msg);

                    if msg_tx.send(msg).is_err() {
                        break;
                    }
                }
                Err(error) => {
                    debug!("ctp: recv error: {error}");
                    let _ = error_tx.send(error.to_string());
                    break;
                }
            };
        }
    }
}
416
417async fn handshake<R, W>(
424 mut reader: R,
425 mut writer: W,
426 version: Version,
427 server_fqdn: Option<String>,
428) -> anyhow::Result<()>
429where
430 R: AsyncRead + Unpin,
431 W: AsyncWrite + Unpin,
432{
433 const MAGIC: u64 = 0x477574656e546167;
435
436 writer.write_u64(MAGIC).await?;
437
438 let hello = Hello {
439 version: version.clone(),
440 server_fqdn: server_fqdn.clone(),
441 };
442 write_message(&mut writer, Some(&hello)).await?;
443
444 let peer_magic = reader.read_u64().await?;
445 if peer_magic != MAGIC {
446 bail!("invalid protocol magic: {peer_magic:#x}");
447 }
448
449 let Hello {
450 version: peer_version,
451 server_fqdn: peer_server_fqdn,
452 } = read_message(&mut reader).await?;
453
454 if peer_version != version {
455 bail!("version mismatch: {peer_version} != {version}");
456 }
457 if let (Some(other), Some(mine)) = (&peer_server_fqdn, &server_fqdn) {
458 if other != mine {
459 bail!("server FQDN mismatch: {other} != {mine}");
460 }
461 }
462
463 Ok(())
464}
465
/// Greeting exchanged by both peers during `handshake`, right after the magic.
///
/// Encoded with bincode via `wire_encode`. NOTE(review): bincode encodes struct
/// fields positionally, so reordering or retyping these fields changes the wire
/// format and breaks compatibility with peers.
#[derive(Debug, Serialize, Deserialize)]
struct Hello {
    /// The sender's protocol version; `handshake` rejects any mismatch.
    version: Version,
    /// The server FQDN as known to the sender, if any.
    ///
    /// For the server this is its configured FQDN. For a client it is the host
    /// parsed from the dialed address. The values are compared only when both
    /// sides provide one.
    server_fqdn: Option<String>,
}
474
475async fn write_message<W, M>(mut writer: W, msg: Option<&M>) -> anyhow::Result<()>
480where
481 W: AsyncWrite + Unpin,
482 M: Message,
483{
484 let bytes = match msg {
485 Some(msg) => &*wire_encode(msg)?,
486 None => &[],
487 };
488
489 let len = bytes.len().cast_into();
490 writer.write_u64(len).await?;
491 writer.write_all(bytes).await?;
492
493 Ok(())
494}
495
496async fn read_message<R, M>(mut reader: R) -> anyhow::Result<M>
498where
499 R: AsyncRead + Unpin,
500 M: Message,
501{
502 let mut len = 0;
504 while len == 0 {
505 len = reader.read_u64().await?;
506 }
507
508 let mut bytes = vec![0; len.cast_into()];
509 reader.read_exact(&mut bytes).await?;
510
511 wire_decode(&bytes)
512}
513
514fn wire_encode<M: Message>(msg: &M) -> anyhow::Result<Vec<u8>> {
516 let bytes = bincode::DefaultOptions::new().serialize(msg)?;
517 Ok(bytes)
518}
519
520fn wire_decode<M: Message>(bytes: &[u8]) -> anyhow::Result<M> {
522 let msg = bincode::DefaultOptions::new().deserialize(bytes)?;
523 Ok(msg)
524}