1mod metrics;
24
25use std::convert::Infallible;
26use std::fmt::Debug;
27use std::time::Duration;
28
29use anyhow::{anyhow, bail};
30use async_trait::async_trait;
31use bincode::Options;
32use futures::future;
33use mz_ore::cast::CastInto;
34use mz_ore::netio::{Listener, SocketAddr, Stream, TimedReader, TimedWriter};
35use mz_ore::task::{AbortOnDropHandle, JoinHandle, JoinHandleExt};
36use semver::Version;
37use serde::de::DeserializeOwned;
38use serde::{Deserialize, Serialize};
39use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
40use tokio::sync::{mpsc, oneshot, watch};
41use tracing::{debug, info, trace, warn};
42
43use crate::client::{GenericClient, Partitionable, Partitioned};
44
45pub use metrics::{Metrics, NoopMetrics};
46
47pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static {}
49impl<T: Debug + Send + Sync + Serialize + DeserializeOwned + 'static> Message for T {}
50
/// A CTP client.
///
/// Sends messages of type `Out` to a server and receives messages of type `In` from it, over the
/// single connection established by `Client::connect`.
#[derive(Debug)]
pub struct Client<Out, In> {
    /// The connection to the server.
    conn: Connection<Out, In>,
}
56
57impl<Out: Message, In: Message> Client<Out, In> {
58 pub async fn connect(
63 address: &str,
64 version: Version,
65 connect_timeout: Duration,
66 idle_timeout: Duration,
67 metrics: impl Metrics<Out, In>,
68 ) -> anyhow::Result<Self> {
69 let dest_host = host_from_address(address);
70 let stream = mz_ore::future::timeout(connect_timeout, Stream::connect(address)).await?;
71 info!(%address, "ctp: connected to server");
72
73 let conn = Connection::start(stream, version, dest_host, idle_timeout, metrics).await?;
74 Ok(Self { conn })
75 }
76}
77
/// Extract the host part of a network address.
///
/// Only addresses of the form `<host>:<port>` or `<protocol>:<host>:<port>` are recognized, and
/// `<port>` must parse as a `u16`. Anything else, including IPv6 literals such as `[::1]:80`,
/// yields `None`.
fn host_from_address(address: &str) -> Option<String> {
    let parts: Vec<&str> = address.split(':').collect();
    let (host, port) = match parts.as_slice() {
        [host, port] | [_, host, port] => (*host, *port),
        _ => return None,
    };

    // The host is only reported when the port component is a valid port number.
    port.parse::<u16>().ok().map(|_| host.to_owned())
}
93
94impl<Out, In> Client<Out, In>
95where
96 Out: Message,
97 In: Message,
98 (Out, In): Partitionable<Out, In>,
99{
100 pub async fn connect_partitioned(
102 addresses: Vec<String>,
103 version: Version,
104 connect_timeout: Duration,
105 idle_timeout: Duration,
106 metrics: impl Metrics<Out, In>,
107 ) -> anyhow::Result<Partitioned<Self, Out, In>> {
108 let connects = addresses.iter().map(|addr| {
109 Self::connect(
110 addr,
111 version.clone(),
112 connect_timeout,
113 idle_timeout,
114 metrics.clone(),
115 )
116 });
117 let clients = future::try_join_all(connects).await?;
118 Ok(Partitioned::new(clients))
119 }
120}
121
122#[async_trait]
123impl<Out: Message, In: Message> GenericClient<Out, In> for Client<Out, In> {
124 async fn send(&mut self, cmd: Out) -> anyhow::Result<()> {
125 self.conn.send(cmd).await
126 }
127
128 async fn recv(&mut self) -> anyhow::Result<Option<In>> {
132 self.conn.recv().await.map(Some)
134 }
135}
136
137pub async fn serve<In, Out, H>(
139 address: SocketAddr,
140 version: Version,
141 server_fqdn: Option<String>,
142 idle_timeout: Duration,
143 handler_fn: impl Fn() -> H,
144 metrics: impl Metrics<Out, In>,
145) -> anyhow::Result<()>
146where
147 In: Message,
148 Out: Message,
149 H: GenericClient<In, Out> + 'static,
150{
151 let mut connection_task: Option<(JoinHandle<()>, oneshot::Sender<()>)> = None;
158
159 let listener = Listener::bind(&address).await?;
160 info!(%address, "ctp: listening for client connections");
161
162 loop {
163 let (stream, peer) = listener.accept().await?;
164 info!(%peer, "ctp: accepted client connection");
165
166 if let Some((task, token)) = connection_task.take() {
168 drop(token);
169 task.wait_and_assert_finished().await;
170 }
171
172 let handler = handler_fn();
173 let version = version.clone();
174 let server_fqdn = server_fqdn.clone();
175 let metrics = metrics.clone();
176 let (cancel_tx, cancel_rx) = oneshot::channel();
177
178 let handle = mz_ore::task::spawn(|| "ctp::connection", async move {
179 let Err(error) = serve_connection(
180 stream,
181 handler,
182 version,
183 server_fqdn,
184 idle_timeout,
185 cancel_rx,
186 metrics,
187 )
188 .await;
189 info!("ctp: connection failed: {error}");
190 });
191
192 connection_task = Some((handle, cancel_tx));
193 }
194}
195
196async fn serve_connection<In, Out, H>(
198 stream: Stream,
199 mut handler: H,
200 version: Version,
201 server_fqdn: Option<String>,
202 timeout: Duration,
203 cancel_rx: oneshot::Receiver<()>,
204 metrics: impl Metrics<Out, In>,
205) -> anyhow::Result<Infallible>
206where
207 In: Message,
208 Out: Message,
209 H: GenericClient<In, Out>,
210{
211 let mut conn = Connection::start(stream, version, server_fqdn, timeout, metrics).await?;
212
213 let mut cancel_rx = cancel_rx;
214 loop {
215 tokio::select! {
216 inbound = conn.recv() => {
218 let msg = inbound?;
219 handler.send(msg).await?;
220 },
221 outbound = handler.recv() => match outbound? {
223 Some(msg) => conn.send(msg).await?,
224 None => bail!("client disconnected"),
225 },
226 _ = &mut cancel_rx => bail!("connection canceled"),
227 }
228 }
229}
230
/// A CTP connection that has completed its handshake.
///
/// Reading from and writing to the underlying stream happen on two background tasks spawned by
/// `Connection::start`. This struct communicates with them through channels.
#[derive(Debug)]
struct Connection<Out, In> {
    /// Outbound messages, consumed by the send task.
    msg_tx: mpsc::Sender<Out>,
    /// Inbound messages, produced by the recv task.
    msg_rx: mpsc::Receiver<In>,
    /// The latest error reported by either background task.
    error_rx: watch::Receiver<String>,

    /// The send and recv tasks. They are wrapped with `abort_on_drop`, so dropping the
    /// connection stops them. Declared last, so it is dropped after the channel ends above.
    _tasks: [AbortOnDropHandle<()>; 2],
}
255
256impl<Out: Message, In: Message> Connection<Out, In> {
257 const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
259 const MIN_TIMEOUT: Duration = Duration::from_secs(2);
264
265 async fn start(
267 stream: Stream,
268 version: Version,
269 server_fqdn: Option<String>,
270 mut timeout: Duration,
271 metrics: impl Metrics<Out, In>,
272 ) -> anyhow::Result<Self> {
273 if timeout < Self::MIN_TIMEOUT {
274 warn!(
275 ?timeout,
276 "ctp: configured timeout is less than minimum timeout",
277 );
278 timeout = Self::MIN_TIMEOUT;
279 }
280
281 let (reader, writer) = stream.split();
282
283 let reader = TimedReader::new(reader, timeout);
285 let writer = TimedWriter::new(writer, timeout);
286 let mut reader = metrics::Reader::new(reader, metrics.clone());
288 let mut writer = metrics::Writer::new(writer, metrics.clone());
289
290 handshake(&mut reader, &mut writer, version, server_fqdn).await?;
291
292 let (out_tx, out_rx) = mpsc::channel(1024);
293 let (in_tx, in_rx) = mpsc::channel(1024);
294 let (error_tx, error_rx) = watch::channel("connection closed".into());
297
298 let send_task = mz_ore::task::spawn(
299 || "ctp::send",
300 Self::run_send_task(writer, out_rx, error_tx.clone(), metrics.clone()),
301 );
302 let recv_task = mz_ore::task::spawn(
303 || "ctp::recv",
304 Self::run_recv_task(reader, in_tx, error_tx, metrics),
305 );
306
307 Ok(Self {
308 msg_tx: out_tx,
309 msg_rx: in_rx,
310 error_rx,
311 _tasks: [send_task.abort_on_drop(), recv_task.abort_on_drop()],
312 })
313 }
314
315 async fn send(&mut self, msg: Out) -> anyhow::Result<()> {
317 match self.msg_tx.send(msg).await {
318 Ok(()) => Ok(()),
319 Err(_) => bail!(self.collect_error().await),
320 }
321 }
322
323 async fn recv(&mut self) -> anyhow::Result<In> {
329 match self.msg_rx.recv().await {
331 Some(msg) => Ok(msg),
332 None => bail!(self.collect_error().await),
333 }
334 }
335
336 async fn collect_error(&mut self) -> String {
338 let _ = self.error_rx.changed().await;
340 self.error_rx.mark_changed();
343
344 self.error_rx.borrow().clone()
345 }
346
347 async fn run_send_task<W: AsyncWrite + Unpin>(
349 mut writer: W,
350 mut msg_rx: mpsc::Receiver<Out>,
351 error_tx: watch::Sender<String>,
352 mut metrics: impl Metrics<Out, In>,
353 ) {
354 loop {
355 let msg = tokio::select! {
356 msg = msg_rx.recv() => match msg {
358 Some(msg) => {
359 trace!(?msg, "ctp: sending message");
360 Some(msg)
361 }
362 None => break,
363 },
364 _ = tokio::time::sleep(Self::KEEPALIVE_INTERVAL) => {
366 trace!("ctp: sending keepalive");
367 None
368 },
369 };
370
371 if let Err(error) = write_message(&mut writer, msg.as_ref()).await {
372 debug!("ctp: send error: {error}");
373 let _ = error_tx.send(error.to_string());
374 break;
375 };
376
377 if let Some(msg) = &msg {
378 metrics.message_sent(msg);
379 }
380 }
381 }
382
383 async fn run_recv_task<R: AsyncRead + Unpin>(
385 mut reader: R,
386 msg_tx: mpsc::Sender<In>,
387 error_tx: watch::Sender<String>,
388 mut metrics: impl Metrics<Out, In>,
389 ) {
390 loop {
391 match read_message(&mut reader).await {
392 Ok(msg) => {
393 trace!(?msg, "ctp: received message");
394 metrics.message_received(&msg);
395
396 if msg_tx.send(msg).await.is_err() {
397 break;
398 }
399 }
400 Err(error) => {
401 debug!("ctp: recv error: {error}");
402 let _ = error_tx.send(error.to_string());
403 break;
404 }
405 };
406 }
407 }
408}
409
/// A [`GenericClient`] that forwards commands and responses over unbounded channels.
///
/// Commands passed to `send` are pushed into `tx`. Responses returned by `recv` are taken from
/// `rx`. Because it implements `GenericClient<In, Out>`, it can be used as the handler type for
/// [`serve`].
#[derive(Debug)]
pub struct ChannelHandler<In, Out> {
    /// Destination of commands received through `send`.
    tx: mpsc::UnboundedSender<In>,
    /// Source of responses handed out by `recv`.
    rx: mpsc::UnboundedReceiver<Out>,
}
416
417impl<In, Out> ChannelHandler<In, Out> {
418 pub fn new(tx: mpsc::UnboundedSender<In>, rx: mpsc::UnboundedReceiver<Out>) -> Self {
419 Self { tx, rx }
420 }
421}
422
423#[async_trait]
424impl<In: Message, Out: Message> GenericClient<In, Out> for ChannelHandler<In, Out> {
425 async fn send(&mut self, cmd: In) -> anyhow::Result<()> {
426 let result = self.tx.send(cmd);
427 result.map_err(|_| anyhow!("client channel disconnected"))
428 }
429
430 async fn recv(&mut self) -> anyhow::Result<Option<Out>> {
434 match self.rx.recv().await {
436 Some(resp) => Ok(Some(resp)),
437 None => bail!("client channel disconnected"),
438 }
439 }
440}
441
442async fn handshake<R, W>(
449 mut reader: R,
450 mut writer: W,
451 version: Version,
452 server_fqdn: Option<String>,
453) -> anyhow::Result<()>
454where
455 R: AsyncRead + Unpin,
456 W: AsyncWrite + Unpin,
457{
458 const MAGIC: u64 = 0x477574656e546167;
460
461 writer.write_u64(MAGIC).await?;
462
463 let hello = Hello {
464 version: version.clone(),
465 server_fqdn: server_fqdn.clone(),
466 };
467 write_message(&mut writer, Some(&hello)).await?;
468
469 let peer_magic = reader.read_u64().await?;
470 if peer_magic != MAGIC {
471 bail!("invalid protocol magic: {peer_magic:#x}");
472 }
473
474 let Hello {
475 version: peer_version,
476 server_fqdn: peer_server_fqdn,
477 } = read_message(&mut reader).await?;
478
479 if peer_version != version {
480 bail!("version mismatch: {peer_version} != {version}");
481 }
482 if let (Some(other), Some(mine)) = (&peer_server_fqdn, &server_fqdn) {
483 if other != mine {
484 bail!("server FQDN mismatch: {other} != {mine}");
485 }
486 }
487
488 Ok(())
489}
490
/// The message both peers exchange during the handshake.
///
/// Encoded with bincode (see `wire_encode`), which serializes fields in declaration order. The
/// field order is therefore part of the wire format and must not change.
#[derive(Debug, Serialize, Deserialize)]
struct Hello {
    /// The sender's protocol version; it must match the receiver's exactly.
    version: Version,
    /// The server FQDN known to the sender, if any. It is compared with the receiver's value
    /// only when both are present.
    server_fqdn: Option<String>,
}
499
500async fn write_message<W, M>(mut writer: W, msg: Option<&M>) -> anyhow::Result<()>
505where
506 W: AsyncWrite + Unpin,
507 M: Message,
508{
509 let bytes = match msg {
510 Some(msg) => &*wire_encode(msg)?,
511 None => &[],
512 };
513
514 let len = bytes.len().cast_into();
515 writer.write_u64(len).await?;
516 writer.write_all(bytes).await?;
517
518 Ok(())
519}
520
521async fn read_message<R, M>(mut reader: R) -> anyhow::Result<M>
523where
524 R: AsyncRead + Unpin,
525 M: Message,
526{
527 let mut len = 0;
529 while len == 0 {
530 len = reader.read_u64().await?;
531 }
532
533 let mut bytes = vec![0; len.cast_into()];
534 reader.read_exact(&mut bytes).await?;
535
536 wire_decode(&bytes)
537}
538
539fn wire_encode<M: Message>(msg: &M) -> anyhow::Result<Vec<u8>> {
541 let bytes = bincode::DefaultOptions::new().serialize(msg)?;
542 Ok(bytes)
543}
544
545fn wire_decode<M: Message>(bytes: &[u8]) -> anyhow::Result<M> {
547 let msg = bincode::DefaultOptions::new().deserialize(bytes)?;
548 Ok(msg)
549}