// mz_persist_client/rpc.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! gRPC-based implementations of Persist PubSub client and server.

use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::time::{Duration, Instant, SystemTime};

use anyhow::{Error, anyhow};
use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use futures_util::StreamExt;
use mz_dyncfg::Config;
use mz_ore::cast::CastFrom;
use mz_ore::collections::{HashMap, HashSet};
use mz_ore::metrics::MetricsRegistry;
use mz_ore::retry::RetryResult;
use mz_ore::task::JoinHandle;
use mz_persist::location::VersionedData;
use mz_proto::{ProtoType, RustType};
use prost::Message;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::error::TrySendError;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap};
use tonic::transport::Endpoint;
use tonic::{Extensions, Request, Response, Status, Streaming};
use tracing::{Instrument, debug, error, info, info_span, warn};

use crate::ShardId;
use crate::cache::{DynState, StateCache};
use crate::cfg::PersistConfig;
use crate::internal::metrics::{PubSubClientCallMetrics, PubSubServerMetrics};
use crate::internal::service::proto_persist_pub_sub_client::ProtoPersistPubSubClient;
use crate::internal::service::proto_persist_pub_sub_server::ProtoPersistPubSubServer;
use crate::internal::service::{
    ProtoPubSubMessage, ProtoPushDiff, ProtoSubscribe, ProtoUnsubscribe,
    proto_persist_pub_sub_server, proto_pub_sub_message,
};
use crate::metrics::Metrics;
/// Determines whether PubSub clients should connect to the PubSub server.
pub(crate) const PUBSUB_CLIENT_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_client_enabled",
    true,
    "Whether to connect to the Persist PubSub service.",
);

/// For connected clients, determines whether to push state diffs to the PubSub
/// server. For the server, determines whether to broadcast state diffs to
/// subscribed clients.
pub(crate) const PUBSUB_PUSH_DIFF_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_push_diff_enabled",
    true,
    "Whether to push state diffs to Persist PubSub.",
);

/// For a PubSub client running in the same process as the PubSub server,
/// determines whether `subscribe` calls are delegated to the server. When
/// disabled, same-process subscriptions are no-ops (see
/// [MetricsSameProcessPubSubSender::subscribe]).
pub(crate) const PUBSUB_SAME_PROCESS_DELEGATE_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_same_process_delegate_enabled",
    true,
    "Whether to push state diffs to Persist PubSub on the same process.",
);

/// Timeout per connection attempt to Persist PubSub service.
pub(crate) const PUBSUB_CONNECT_ATTEMPT_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_connect_attempt_timeout",
    Duration::from_secs(5),
    "Timeout per connection attempt to Persist PubSub service.",
);

/// Timeout per request attempt to Persist PubSub service.
pub(crate) const PUBSUB_REQUEST_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_request_timeout",
    Duration::from_secs(5),
    "Timeout per request attempt to Persist PubSub service.",
);

/// Maximum backoff when retrying connection establishment to Persist PubSub service.
pub(crate) const PUBSUB_CONNECT_MAX_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_connect_max_backoff",
    Duration::from_secs(60),
    "Maximum backoff when retrying connection establishment to Persist PubSub service.",
);

/// Size of channel used to buffer send messages to PubSub service.
pub(crate) const PUBSUB_CLIENT_SENDER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_sender_channel_size",
    25,
    "Size of channel used to buffer send messages to PubSub service.",
);

/// Size of channel used to buffer received messages from PubSub service.
pub(crate) const PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_receiver_channel_size",
    25,
    "Size of channel used to buffer received messages from PubSub service.",
);

/// Size of channel used per connection to buffer broadcasted messages from PubSub server.
pub(crate) const PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_server_connection_channel_size",
    25,
    "Size of channel used per connection to buffer broadcasted messages from PubSub server.",
);

/// Size of channel used by the state cache to broadcast shard state references.
pub(crate) const PUBSUB_STATE_CACHE_SHARD_REF_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_state_cache_shard_ref_channel_size",
    25,
    "Size of channel used by the state cache to broadcast shard state references.",
);

/// Backoff after an established connection to Persist PubSub service fails.
pub(crate) const PUBSUB_RECONNECT_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_reconnect_backoff",
    Duration::from_secs(5),
    "Backoff after an established connection to Persist PubSub service fails.",
);
137
/// Top-level Trait to create a PubSubClient.
///
/// Returns a [PubSubClientConnection] with a [PubSubSender] for issuing RPCs to the PubSub
/// server, and a [PubSubReceiver] that receives messages, such as state diffs.
pub trait PersistPubSubClient {
    /// Receive handles with which to push and subscribe to diffs.
    ///
    /// Note that this is a synchronous call; the gRPC implementation
    /// establishes and maintains its connection in a background task
    /// (see [GrpcPubSubClient]).
    fn connect(
        pubsub_config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) -> PubSubClientConnection;
}
149
/// Wrapper type for a matching [PubSubSender] and [PubSubReceiver] client pair.
#[derive(Debug)]
pub struct PubSubClientConnection {
    /// The sender client to Persist PubSub. Shared (`Arc`) so that many
    /// shard handles can push diffs and subscribe through one connection.
    pub sender: Arc<dyn PubSubSender>,
    /// The receiver client to Persist PubSub. Boxed stream with a single
    /// consumer that receives messages for all subscribed shards.
    pub receiver: Box<dyn PubSubReceiver>,
}
158
159impl PubSubClientConnection {
160    /// Creates a new [PubSubClientConnection] from a matching [PubSubSender] and [PubSubReceiver].
161    pub fn new(sender: Arc<dyn PubSubSender>, receiver: Box<dyn PubSubReceiver>) -> Self {
162        Self { sender, receiver }
163    }
164
165    /// Creates a no-op [PubSubClientConnection] that neither sends nor receives messages.
166    pub fn noop() -> Self {
167        Self {
168            sender: Arc::new(NoopPubSubSender),
169            receiver: Box::new(futures::stream::empty()),
170        }
171    }
172}
173
/// The public send-side client to Persist PubSub.
pub trait PubSubSender: std::fmt::Debug + Send + Sync {
    /// Push a diff to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribe the corresponding [PubSubReceiver] to diffs for the given shard.
    ///
    /// Returns a token that, when dropped, will unsubscribe the client from the
    /// shard.
    ///
    /// If the client is already subscribed to the shard, repeated calls will make
    /// no further calls to the server and instead return clones of the `Arc<ShardSubscriptionToken>`.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken>;
}
188
/// The internal send-side client trait to Persist PubSub, responsible for issuing RPCs
/// to the PubSub service. This trait is separated out from [PubSubSender] to keep the
/// client implementations straightforward, while offering a more ergonomic public API
/// in [PubSubSender].
///
/// Unlike [PubSubSender::subscribe], subscriptions here are not token-managed:
/// callers (e.g. [ShardSubscriptionToken]'s `Drop` impl) must pair subscribe
/// and unsubscribe calls themselves.
trait PubSubSenderInternal: Debug + Send + Sync {
    /// Push a diff to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribe the corresponding [PubSubReceiver] to diffs for the given shard.
    ///
    /// This call is idempotent and is a no-op for an already subscribed shard.
    fn subscribe(&self, shard_id: &ShardId);

    /// Unsubscribe the corresponding [PubSubReceiver] from diffs for the given shard.
    ///
    /// This call is idempotent and is a no-op for already unsubscribed shards.
    fn unsubscribe(&self, shard_id: &ShardId);
}
207
/// The receive-side client to Persist PubSub.
///
/// Returns diffs (and maybe in the future, blobs) for any shards subscribed to
/// by the corresponding `PubSubSender`.
pub trait PubSubReceiver:
    Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}

// Blanket impl: any stream of `ProtoPubSubMessage` with the required
// auto-traits is usable as a `PubSubReceiver` (e.g. `ReceiverStream`,
// `futures::stream::empty`).
impl<T> PubSubReceiver for T where
    T: Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}
221
/// A token corresponding to a subscription to diffs for a particular shard.
///
/// When dropped, the client that originated the token will be unsubscribed
/// from further diffs to the shard.
pub struct ShardSubscriptionToken {
    // The shard this token's subscription covers.
    pub(crate) shard_id: ShardId,
    // The sender to issue the unsubscribe RPC on when this token drops.
    sender: Arc<dyn PubSubSenderInternal>,
}
230
231impl Debug for ShardSubscriptionToken {
232    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233        let ShardSubscriptionToken {
234            shard_id,
235            sender: _sender,
236        } = self;
237        write!(f, "ShardSubscriptionToken({})", shard_id)
238    }
239}
240
impl Drop for ShardSubscriptionToken {
    // Dropping the token tears down the subscription it represents.
    fn drop(&mut self) {
        self.sender.unsubscribe(&self.shard_id);
    }
}
246
/// A gRPC metadata key to indicate the caller id of a client.
///
/// Clients populate the value from [PersistPubSubClientConfig::caller_id];
/// it is used for debugging only.
pub const PERSIST_PUBSUB_CALLER_KEY: &str = "persist-pubsub-caller-id";
249
/// Client configuration for connecting to a remote PubSub server.
#[derive(Debug)]
pub struct PersistPubSubClientConfig {
    /// Connection address for the pubsub server, e.g. `http://localhost:6879`
    pub url: String,
    /// A caller ID for the client. Used for debugging.
    pub caller_id: String,
    /// A copy of [PersistConfig], consulted for the `persist_pubsub_*`
    /// dyncfgs that control connection behavior.
    pub persist_cfg: PersistConfig,
}
260
/// A [PersistPubSubClient] implementation backed by gRPC.
///
/// Returns a [PubSubClientConnection] backed by channels that submit and receive
/// messages to and from a long-lived bidirectional gRPC stream. The gRPC stream
/// will be transparently reestablished if the connection is lost.
#[derive(Debug)]
pub struct GrpcPubSubClient;
268
impl GrpcPubSubClient {
    /// Maintains the connection to the PubSub server for the lifetime of the
    /// client, transparently re-establishing the gRPC stream (with backoff)
    /// whenever it terminates.
    ///
    /// Returns only when a fatal error occurs (e.g. an unparseable URL) or
    /// when the receive side has shut down.
    async fn reconnect_to_server_forever(
        send_requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
        receiver_input: &tokio::sync::mpsc::Sender<ProtoPubSubMessage>,
        sender: Arc<SubscriptionTrackingSender>,
        metadata: MetadataMap,
        config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) {
        // Once enabled, the PubSub server cannot be disabled or otherwise
        // reconfigured. So we wait for at least one configuration sync to
        // complete. This gives `environmentd` at least one chance to update
        // PubSub configuration parameters. See database-issues#7168 for details.
        config.persist_cfg.configs_synced_once().await;

        let mut is_first_connection_attempt = true;
        loop {
            let sender = Arc::clone(&sender);
            metrics.pubsub_client.grpc_connection.connected.set(0);

            // While the client is disabled, poll the flag rather than connect.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                tokio::time::sleep(Duration::from_secs(5)).await;
                continue;
            }

            // add a bit of backoff when reconnecting after some network/server failure
            if is_first_connection_attempt {
                is_first_connection_attempt = false;
            } else {
                tokio::time::sleep(PUBSUB_RECONNECT_BACKOFF.get(&config.persist_cfg)).await;
            }

            info!("Connecting to Persist PubSub: {}", config.url);
            // Retry connection establishment indefinitely (with clamped
            // backoff); only a bad URL is treated as fatal.
            let client = mz_ore::retry::Retry::default()
                .clamp_backoff(PUBSUB_CONNECT_MAX_BACKOFF.get(&config.persist_cfg))
                .retry_async(|_| async {
                    metrics
                        .pubsub_client
                        .grpc_connection
                        .connect_call_attempt_count
                        .inc();
                    let endpoint = match Endpoint::from_str(&config.url) {
                        Ok(endpoint) => endpoint,
                        Err(err) => return RetryResult::FatalErr(err),
                    };
                    ProtoPersistPubSubClient::connect(
                        endpoint
                            .connect_timeout(
                                PUBSUB_CONNECT_ATTEMPT_TIMEOUT.get(&config.persist_cfg),
                            )
                            .timeout(PUBSUB_REQUEST_TIMEOUT.get(&config.persist_cfg)),
                    )
                    .await
                    .into()
                })
                .await;

            let mut client = match client {
                Ok(client) => client,
                Err(err) => {
                    error!("fatal error connecting to persist pubsub: {:?}", err);
                    return;
                }
            };

            metrics
                .pubsub_client
                .grpc_connection
                .connection_established_count
                .inc();
            metrics.pubsub_client.grpc_connection.connected.set(1);

            info!("Connected to Persist PubSub: {}", config.url);

            // Fresh broadcast receiver per connection; the stable broadcast
            // sender outlives any individual gRPC stream.
            let mut broadcast = BroadcastStream::new(send_requests.subscribe());
            let broadcast_errors = metrics
                .pubsub_client
                .grpc_connection
                .broadcast_recv_lagged_count
                .clone();

            // shard subscriptions are tracked by connection on the server, so if our
            // gRPC stream is ever swapped out, we must inform the server which shards
            // our client intended to be subscribed to.
            let broadcast_messages = async_stream::stream! {
                'reconnect: loop {
                    // If we have active subscriptions, resend them.
                    for id in sender.subscriptions() {
                        debug!("re-subscribing to shard: {id}");
                        let msg = proto_pub_sub_message::Message::Subscribe(
                            ProtoSubscribe {
                                shard_id: id.into_proto(),
                            },
                        );
                        yield create_request(msg);
                    }

                    // Forward on messages from the broadcast channel, reconnecting if necessary.
                    while let Some(message) = broadcast.next().await {
                        debug!("sending pubsub message: {:?}", message);
                        match message {
                            Ok(message) => yield message,
                            // We fell too far behind the broadcast channel and
                            // lost `i` messages; count them and re-send our
                            // subscriptions from the top of the loop.
                            Err(BroadcastStreamRecvError::Lagged(i)) => {
                                broadcast_errors.inc_by(i);
                                continue 'reconnect;
                            }
                        }
                    }
                    debug!("exhausted pubsub broadcast stream; shutting down");
                    break;
                }
            };
            let pubsub_request =
                Request::from_parts(metadata.clone(), Extensions::default(), broadcast_messages);

            let responses = match client.pub_sub(pubsub_request).await {
                Ok(response) => response.into_inner(),
                Err(err) => {
                    warn!("pub_sub rpc error: {:?}", err);
                    continue;
                }
            };

            // Blocks until the response stream ends (transient failure or
            // server hangup) or the receiver is dropped (shutdown).
            let stream_completed = GrpcPubSubClient::consume_grpc_stream(
                responses,
                receiver_input,
                &config,
                metrics.as_ref(),
            )
            .await;

            match stream_completed {
                // common case: reconnect due to some transient error
                Ok(_) => continue,
                // uncommon case: we should stop connecting to the PubSub server entirely.
                // in practice, we should only see this during shut down.
                Err(err) => {
                    warn!("shutting down connection loop to Persist PubSub: {}", err);
                    return;
                }
            }
        }
    }

    /// Forwards messages from an active gRPC response stream into
    /// `receiver_input` until the stream ends.
    ///
    /// Returns `Ok(())` when the caller should reconnect (stream ended or
    /// errored, or the client was disabled), and `Err` when the receive side
    /// has been dropped and the connection loop should exit entirely.
    async fn consume_grpc_stream(
        mut responses: Streaming<ProtoPubSubMessage>,
        receiver_input: &Sender<ProtoPubSubMessage>,
        config: &PersistPubSubClientConfig,
        metrics: &Metrics,
    ) -> Result<(), Error> {
        loop {
            // Re-check the enable flag on every message so a disable takes
            // effect without waiting for the stream to end.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                return Ok(());
            }

            debug!("awaiting next pubsub response");
            match responses.next().await {
                Some(Ok(message)) => {
                    debug!("received pubsub message: {:?}", message);
                    match receiver_input.send(message).await {
                        Ok(_) => {}
                        // if the receiver has dropped, we can drop our
                        // no-longer-needed grpc connection entirely.
                        Err(err) => {
                            return Err(anyhow!("closing pubsub grpc client connection: {}", err));
                        }
                    }
                }
                Some(Err(err)) => {
                    metrics.pubsub_client.grpc_connection.grpc_error_count.inc();
                    warn!("pubsub client error: {:?}", err);
                    return Ok(());
                }
                None => return Ok(()),
            }
        }
    }
}
447
impl PersistPubSubClient for GrpcPubSubClient {
    /// Returns immediately with channel-backed handles; the gRPC connection
    /// itself is established (and re-established) by a background task.
    fn connect(config: PersistPubSubClientConfig, metrics: Arc<Metrics>) -> PubSubClientConnection {
        // Create a stable channel for our client to transmit message into our gRPC stream. We use a
        // broadcast to allow us to create new Receivers on demand, in case the underlying gRPC stream
        // is swapped out (e.g. due to connection failure). It is expected that only 1 Receiver is
        // ever active at a given time.
        let (send_requests, _) = tokio::sync::broadcast::channel(
            PUBSUB_CLIENT_SENDER_CHANNEL_SIZE.get(&config.persist_cfg),
        );
        // Create a stable channel to receive messages from our gRPC stream. The input end lives inside
        // a task that continuously reads from the active gRPC stream, decoupling the `PubSubReceiver`
        // from the lifetime of a specific gRPC connection.
        let (receiver_input, receiver_output) = tokio::sync::mpsc::channel(
            PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&config.persist_cfg),
        );

        // Track subscriptions outside of the raw gRPC sender so they can be
        // replayed to the server after a reconnect (see
        // `reconnect_to_server_forever`).
        let sender = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            GrpcPubSubSender {
                metrics: Arc::clone(&metrics),
                requests: send_requests.clone(),
            },
        )));
        let pubsub_sender = Arc::clone(&sender);
        mz_ore::task::spawn(
            || "persist::rpc::client::connection".to_string(),
            async move {
                // Attach the caller id to every request for server-side
                // debugging; fall back to "unknown" if it is not valid ASCII
                // metadata.
                let mut metadata = MetadataMap::new();
                metadata.insert(
                    AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY),
                    AsciiMetadataValue::try_from(&config.caller_id)
                        .unwrap_or_else(|_| AsciiMetadataValue::from_static("unknown")),
                );

                GrpcPubSubClient::reconnect_to_server_forever(
                    send_requests,
                    &receiver_input,
                    pubsub_sender,
                    metadata,
                    config,
                    metrics,
                )
                .await;
            },
        );

        PubSubClientConnection {
            sender,
            receiver: Box::new(ReceiverStream::new(receiver_output)),
        }
    }
}
499
/// An internal, gRPC-backed implementation of [PubSubSender].
struct GrpcPubSubSender {
    // Metrics recording the outcome and size of each outgoing call.
    metrics: Arc<Metrics>,
    // Broadcast channel feeding the active gRPC connection task (if any).
    requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
}
505
506impl Debug for GrpcPubSubSender {
507    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
508        let GrpcPubSubSender {
509            metrics: _metrics,
510            requests: _requests,
511        } = self;
512
513        write!(f, "GrpcPubSubSender")
514    }
515}
516
517fn create_request(message: proto_pub_sub_message::Message) -> ProtoPubSubMessage {
518    let now = SystemTime::now()
519        .duration_since(SystemTime::UNIX_EPOCH)
520        .expect("failed to get millis since epoch");
521
522    ProtoPubSubMessage {
523        timestamp: Some(now.into_proto()),
524        message: Some(message),
525    }
526}
527
528impl GrpcPubSubSender {
529    fn send(&self, message: proto_pub_sub_message::Message, metrics: &PubSubClientCallMetrics) {
530        let size = message.encoded_len();
531
532        match self.requests.send(create_request(message)) {
533            Ok(_) => {
534                metrics.succeeded.inc();
535                metrics.bytes_sent.inc_by(u64::cast_from(size));
536            }
537            Err(err) => {
538                metrics.failed.inc();
539                debug!("error sending client message: {}", err);
540            }
541        }
542    }
543}
544
545impl PubSubSenderInternal for GrpcPubSubSender {
546    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
547        self.send(
548            proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
549                shard_id: shard_id.into_proto(),
550                seqno: diff.seqno.into_proto(),
551                diff: diff.data.clone(),
552            }),
553            &self.metrics.pubsub_client.sender.push,
554        )
555    }
556
557    fn subscribe(&self, shard_id: &ShardId) {
558        self.send(
559            proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
560                shard_id: shard_id.into_proto(),
561            }),
562            &self.metrics.pubsub_client.sender.subscribe,
563        )
564    }
565
566    fn unsubscribe(&self, shard_id: &ShardId) {
567        self.send(
568            proto_pub_sub_message::Message::Unsubscribe(ProtoUnsubscribe {
569                shard_id: shard_id.into_proto(),
570            }),
571            &self.metrics.pubsub_client.sender.unsubscribe,
572        )
573    }
574}
575
/// A wrapper for a [PubSubSenderInternal] that implements [PubSubSender]
/// by maintaining a map of active shard subscriptions to their tokens.
#[derive(Debug)]
struct SubscriptionTrackingSender {
    // The underlying sender that actually issues the RPCs.
    delegate: Arc<dyn PubSubSenderInternal>,
    // Weak refs to outstanding tokens; an entry whose token has been dropped
    // is treated as unsubscribed and pruned lazily.
    subscribes: Arc<Mutex<BTreeMap<ShardId, Weak<ShardSubscriptionToken>>>>,
}
583
584impl SubscriptionTrackingSender {
585    fn new(sender: Arc<dyn PubSubSenderInternal>) -> Self {
586        Self {
587            delegate: sender,
588            subscribes: Default::default(),
589        }
590    }
591
592    fn subscriptions(&self) -> Vec<ShardId> {
593        let mut subscribes = self.subscribes.lock().expect("lock");
594        let mut out = Vec::with_capacity(subscribes.len());
595        subscribes.retain(|shard_id, token| {
596            if token.upgrade().is_none() {
597                false
598            } else {
599                debug!("reconnecting to: {}", shard_id);
600                out.push(*shard_id);
601                true
602            }
603        });
604        out
605    }
606}
607
impl PubSubSender for SubscriptionTrackingSender {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff)
    }

    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        let mut subscribes = self.subscribes.lock().expect("lock");
        if let Some(token) = subscribes.get(shard_id) {
            match token.upgrade() {
                // The last token for this shard was already dropped; clear the
                // stale entry and fall through to create a new subscription.
                None => assert!(subscribes.remove(shard_id).is_some()),
                // An active subscription exists: hand out another handle to it
                // without issuing a second subscribe RPC.
                Some(token) => {
                    return Arc::clone(&token);
                }
            }
        }

        let pubsub_sender = Arc::clone(&self.delegate);
        let token = Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: pubsub_sender,
        });

        // We just removed (or never had) an entry for this shard while holding
        // the lock, so the insert must not clobber an existing token.
        assert!(
            subscribes
                .insert(*shard_id, Arc::downgrade(&token))
                .is_none()
        );

        // NOTE(review): the subscribe RPC is issued while the map lock is
        // still held — presumably to order it before any concurrent
        // subscribe/unsubscribe for the same shard; confirm before changing.
        self.delegate.subscribe(shard_id);

        token
    }
}
641
/// A wrapper intended to provide client-side metrics for a connection
/// that communicates directly with the server state, such as one created
/// by [PersistGrpcPubSubServer::new_same_process_connection].
#[derive(Debug)]
pub struct MetricsSameProcessPubSubSender {
    // Snapshot of PUBSUB_SAME_PROCESS_DELEGATE_ENABLED taken at construction;
    // when false, subscribe() returns no-op tokens instead of delegating.
    delegate_subscribe: bool,
    metrics: Arc<Metrics>,
    delegate: Arc<dyn PubSubSender>,
}
651
652impl MetricsSameProcessPubSubSender {
653    /// Returns a new [MetricsSameProcessPubSubSender], wrapping the given
654    /// `Arc<dyn PubSubSender>`'s calls to provide client-side metrics.
655    pub fn new(
656        cfg: &PersistConfig,
657        pubsub_sender: Arc<dyn PubSubSender>,
658        metrics: Arc<Metrics>,
659    ) -> Self {
660        Self {
661            delegate_subscribe: PUBSUB_SAME_PROCESS_DELEGATE_ENABLED.get(cfg),
662            delegate: pubsub_sender,
663            metrics,
664        }
665    }
666}
667
impl PubSubSender for MetricsSameProcessPubSubSender {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff);
        // The delegate exposes no failure signal, so every push is counted
        // as a success.
        self.metrics.pubsub_client.sender.push.succeeded.inc();
    }

    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        if self.delegate_subscribe {
            let delegate = Arc::clone(&self.delegate);
            delegate.subscribe(shard_id)
        } else {
            // Create a no-op token that does not subscribe nor unsubscribe.
            // This is ideal for single-process persist setups, since the sender and
            // receiver should already share a state cache... but if the diffs are
            // generated remotely but applied on the server, this may cause us to fall
            // back to polling consensus.
            Arc::new(ShardSubscriptionToken {
                shard_id: *shard_id,
                sender: Arc::new(NoopPubSubSender),
            })
        }
    }
}
691
/// A [PubSubSenderInternal] (and [PubSubSender]) that discards all calls.
#[derive(Debug)]
pub(crate) struct NoopPubSubSender;

impl PubSubSenderInternal for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
    fn subscribe(&self, _shard_id: &ShardId) {}
    fn unsubscribe(&self, _shard_id: &ShardId) {}
}
700
impl PubSubSender for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}

    // The returned token's eventual drop calls `unsubscribe` on this same
    // no-op sender, so both subscribe and unsubscribe are side-effect free.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: self,
        })
    }
}
711
/// Spawns a Tokio task that consumes a [PubSubReceiver], applying its diffs to a [StateCache].
pub(crate) fn subscribe_state_cache_to_pubsub(
    cache: Arc<StateCache>,
    mut pubsub_receiver: Box<dyn PubSubReceiver>,
) -> JoinHandle<()> {
    // Task-local cache of weak refs to shard state, so the common case can
    // apply a diff without consulting the `StateCache` at all.
    let mut state_refs: HashMap<ShardId, Weak<dyn DynState>> = HashMap::new();
    let receiver_metrics = cache.metrics.pubsub_client.receiver.clone();

    mz_ore::task::spawn(
        || "persist::rpc::client::state_cache_diff_apply",
        async move {
            // Runs until the receiver stream is exhausted (i.e. the sending
            // side of the connection has gone away).
            while let Some(msg) = pubsub_receiver.next().await {
                match msg.message {
                    Some(proto_pub_sub_message::Message::PushDiff(diff)) => {
                        receiver_metrics.push_received.inc();
                        let shard_id = diff.shard_id.into_rust().expect("valid shard id");
                        let diff = VersionedData {
                            seqno: diff.seqno.into_rust().expect("valid SeqNo"),
                            data: diff.diff,
                        };
                        debug!(
                            "applying pubsub diff {} {} {}",
                            shard_id,
                            diff.seqno,
                            diff.data.len()
                        );

                        let mut pushed_diff = false;
                        if let Some(state_ref) = state_refs.get(&shard_id) {
                            // common case: we have a reference to the shard state already
                            // and can apply our diff directly.
                            if let Some(state) = state_ref.upgrade() {
                                state.push_diff(diff.clone());
                                pushed_diff = true;
                                receiver_metrics.state_pushed_diff_fast_path.inc();
                            }
                        }

                        if !pushed_diff {
                            // uncommon case: we either don't have a reference yet, or ours
                            // is out-of-date (e.g. the shard was dropped and then re-added
                            // to StateCache). here we'll fetch the latest, try to apply the
                            // diff again, and update our local reference.
                            let state_ref = cache.get_state_weak(&shard_id);
                            match state_ref {
                                // The cache no longer knows this shard; drop
                                // our stale local ref too.
                                None => {
                                    state_refs.remove(&shard_id);
                                }
                                Some(state_ref) => {
                                    if let Some(state) = state_ref.upgrade() {
                                        state.push_diff(diff);
                                        pushed_diff = true;
                                        state_refs.insert(shard_id, state_ref);
                                    } else {
                                        state_refs.remove(&shard_id);
                                    }
                                }
                            }

                            if pushed_diff {
                                receiver_metrics.state_pushed_diff_slow_path_succeeded.inc();
                            } else {
                                receiver_metrics.state_pushed_diff_slow_path_failed.inc();
                            }
                        }

                        // Record approximate end-to-end latency from the
                        // sender's wall-clock timestamp. Approximate because
                        // it relies on loosely synchronized clocks.
                        if let Some(send_timestamp) = msg.timestamp {
                            let send_timestamp =
                                send_timestamp.into_rust().expect("valid timestamp");
                            let now = SystemTime::now()
                                .duration_since(SystemTime::UNIX_EPOCH)
                                .expect("failed to get millis since epoch");
                            receiver_metrics
                                .approx_diff_latency_seconds
                                .observe((now.saturating_sub(send_timestamp)).as_secs_f64());
                        }
                    }
                    // Anything other than a PushDiff is unexpected on the
                    // client side; count it and carry on.
                    ref msg @ None | ref msg @ Some(_) => {
                        warn!("pubsub client received unexpected message: {:?}", msg);
                        receiver_metrics.unknown_message_received.inc();
                    }
                }
            }
        },
    )
}
798
/// Internal state of a PubSub server implementation.
///
/// Shared (via `Arc`) between the gRPC service and every active
/// [PubSubConnection]; each field is independently synchronized.
#[derive(Debug)]
pub(crate) struct PubSubState {
    /// Assigns a unique ID to each incoming connection.
    connection_id_counter: AtomicUsize,
    /// Maintains a mapping of `ShardId --> [ConnectionId -> Tx]`.
    ///
    /// Used to fan out a diff pushed for a shard to all of that shard's
    /// subscribers, excluding the connection that pushed it.
    shard_subscribers:
        Arc<RwLock<BTreeMap<ShardId, BTreeMap<usize, Sender<Result<ProtoPubSubMessage, Status>>>>>>,
    /// Active connections, keyed by their connection ID.
    connections: Arc<RwLock<HashSet<usize>>>,
    /// Server-side metrics.
    metrics: Arc<PubSubServerMetrics>,
}
812
impl PubSubState {
    /// Registers a new connection, returning a [PubSubConnection] handle.
    ///
    /// `notifier` is the channel over which diffs pushed by other connections
    /// will be forwarded to this connection. The returned handle removes the
    /// connection (and all of its subscriptions) from this state when dropped.
    fn new_connection(
        self: Arc<Self>,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    ) -> PubSubConnection {
        let connection_id = self.connection_id_counter.fetch_add(1, Ordering::SeqCst);
        {
            debug!("inserting connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            // the atomic counter hands out unique ids, so a collision is a bug
            assert!(connections.insert(connection_id));
        }

        self.metrics.active_connections.inc();
        PubSubConnection {
            connection_id,
            notifier,
            state: self,
        }
    }

    /// Unregisters a connection and drops any shard subscriptions it held.
    ///
    /// Panics if `connection_id` is not an active connection.
    fn remove_connection(&self, connection_id: usize) {
        let now = Instant::now();

        {
            debug!("removing connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(
                connections.remove(&connection_id),
                "unknown connection id: {}",
                connection_id
            );
        }

        {
            let mut subscribers = self.shard_subscribers.write().expect("lock poisoned");
            // remove this connection from every shard's subscriber map, and
            // prune shards left with no subscribers at all
            subscribers.retain(|_shard, connections_for_shard| {
                connections_for_shard.remove(&connection_id);
                !connections_for_shard.is_empty()
            });
        }

        self.metrics
            .connection_cleanup_seconds
            .inc_by(now.elapsed().as_secs_f64());
        self.metrics.active_connections.dec();
    }

    /// Broadcasts `data` for `shard_id` to every connection subscribed to the
    /// shard, except the sender (`connection_id`) itself.
    ///
    /// Sends are non-blocking: if a subscriber's channel is full, the diff is
    /// dropped for that subscriber and a metric is incremented rather than
    /// blocking the caller. Panics if `connection_id` is not an active
    /// connection.
    fn push_diff(&self, connection_id: usize, shard_id: &ShardId, data: &VersionedData) {
        let now = Instant::now();
        self.metrics.push_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        let subscribers = self.shard_subscribers.read().expect("lock poisoned");
        if let Some(subscribed_connections) = subscribers.get(shard_id) {
            let mut num_sent = 0;
            let mut data_size = 0;

            for (subscribed_conn_id, tx) in subscribed_connections {
                // skip sending the diff back to the original sender
                if *subscribed_conn_id == connection_id {
                    continue;
                }
                debug!(
                    "server forwarding req to conn {}: {} {} {}",
                    subscribed_conn_id,
                    &shard_id,
                    data.seqno,
                    data.data.len()
                );
                let req = create_request(proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                    seqno: data.seqno.into_proto(),
                    shard_id: shard_id.to_string(),
                    diff: Bytes::clone(&data.data),
                }));
                // overwritten each iteration; used for the bytes metric below.
                // NOTE(review): assumes the encoded size is (near-)identical
                // across subscribers for a single push — confirm if
                // `create_request` embeds per-call data such as a timestamp.
                data_size = req.encoded_len();
                match tx.try_send(Ok(req)) {
                    Ok(_) => {
                        num_sent += 1;
                    }
                    Err(TrySendError::Full(_)) => {
                        self.metrics.broadcasted_diff_dropped_channel_full.inc();
                    }
                    // receiver is gone; the connection's cleanup will remove
                    // it from the subscriber maps
                    Err(TrySendError::Closed(_)) => {}
                };
            }

            self.metrics.broadcasted_diff_count.inc_by(num_sent);
            self.metrics
                .broadcasted_diff_bytes
                .inc_by(num_sent * u64::cast_from(data_size));
        }

        self.metrics
            .push_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Subscribes `connection_id` to `shard_id`, storing `notifier` as the
    /// channel over which diffs for that shard will be forwarded.
    ///
    /// Idempotent: re-subscribing replaces the stored sender. Panics if
    /// `connection_id` is not an active connection.
    fn subscribe(
        &self,
        connection_id: usize,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
        shard_id: &ShardId,
    ) {
        let now = Instant::now();
        self.metrics.subscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            subscribed_shards
                .entry(*shard_id)
                .or_default()
                .insert(connection_id, notifier);
        }

        self.metrics
            .subscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Unsubscribes `connection_id` from `shard_id`, pruning the shard's
    /// entry entirely if it has no subscribers left.
    ///
    /// Panics if `connection_id` is not an active connection.
    fn unsubscribe(&self, connection_id: usize, shard_id: &ShardId) {
        let now = Instant::now();
        self.metrics.unsubscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            if let Entry::Occupied(mut entry) = subscribed_shards.entry(*shard_id) {
                let subscribed_connections = entry.get_mut();
                subscribed_connections.remove(&connection_id);

                if subscribed_connections.is_empty() {
                    entry.remove_entry();
                }
            }
        }

        self.metrics
            .unsubscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Creates an empty state with fresh metrics, for tests.
    #[cfg(test)]
    fn new_for_test() -> Self {
        Self {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(PubSubServerMetrics::new(&MetricsRegistry::new())),
        }
    }

    /// Returns the IDs of all currently-active connections.
    #[cfg(test)]
    fn active_connections(&self) -> HashSet<usize> {
        self.connections.read().expect("lock").clone()
    }

    /// Returns the set of shards that `connection_id` is subscribed to.
    #[cfg(test)]
    fn subscriptions(&self, connection_id: usize) -> HashSet<ShardId> {
        let mut shards = HashSet::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            if subscribed_connections.contains_key(&connection_id) {
                shards.insert(*shard);
            }
        }

        shards
    }

    /// Returns, per shard, how many connections are subscribed to it.
    #[cfg(test)]
    fn shard_subscription_counts(&self) -> mz_ore::collections::HashMap<ShardId, usize> {
        let mut shards = mz_ore::collections::HashMap::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            shards.insert(*shard, subscribed_connections.len());
        }

        shards
    }
}
1020
/// A gRPC-based implementation of a Persist PubSub server.
#[derive(Debug)]
pub struct PersistGrpcPubSubServer {
    /// Persist configuration, consulted for dynamic config values
    /// (e.g. channel sizes) when setting up connections.
    cfg: PersistConfig,
    /// Connection and subscription state shared by all connections
    /// to this server.
    state: Arc<PubSubState>,
}
1027
1028impl PersistGrpcPubSubServer {
1029    /// Creates a new [PersistGrpcPubSubServer].
1030    pub fn new(cfg: &PersistConfig, metrics_registry: &MetricsRegistry) -> Self {
1031        let metrics = PubSubServerMetrics::new(metrics_registry);
1032        let state = Arc::new(PubSubState {
1033            connection_id_counter: AtomicUsize::new(0),
1034            shard_subscribers: Default::default(),
1035            connections: Default::default(),
1036            metrics: Arc::new(metrics),
1037        });
1038
1039        PersistGrpcPubSubServer {
1040            cfg: cfg.clone(),
1041            state,
1042        }
1043    }
1044
1045    /// Creates a connection to [PersistGrpcPubSubServer] that is directly connected
1046    /// to the server state. Calls into this connection do not go over the network
1047    /// nor require message serde.
1048    pub fn new_same_process_connection(&self) -> PubSubClientConnection {
1049        let (tx, rx) =
1050            tokio::sync::mpsc::channel(PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&self.cfg));
1051        let sender: Arc<dyn PubSubSender> = Arc::new(SubscriptionTrackingSender::new(Arc::new(
1052            Arc::clone(&self.state).new_connection(tx),
1053        )));
1054
1055        PubSubClientConnection {
1056            sender,
1057            receiver: Box::new(
1058                ReceiverStream::new(rx).map(|x| x.expect("cannot receive grpc errors locally")),
1059            ),
1060        }
1061    }
1062
1063    /// Starts the gRPC server. Consumes `self` and runs until the task is cancelled.
1064    pub async fn serve(self, listen_addr: SocketAddr) -> Result<(), anyhow::Error> {
1065        // Increase the default message decoding limit to avoid unnecessary panics
1066        tonic::transport::Server::builder()
1067            .add_service(ProtoPersistPubSubServer::new(self).max_decoding_message_size(usize::MAX))
1068            .serve(listen_addr)
1069            .await?;
1070        Ok(())
1071    }
1072
1073    /// Starts the gRPC server with the given listener stream.
1074    /// Consumes `self` and runs until the task is cancelled.
1075    pub async fn serve_with_stream(
1076        self,
1077        listener: tokio_stream::wrappers::TcpListenerStream,
1078    ) -> Result<(), anyhow::Error> {
1079        tonic::transport::Server::builder()
1080            .add_service(ProtoPersistPubSubServer::new(self))
1081            .serve_with_incoming(listener)
1082            .await?;
1083        Ok(())
1084    }
1085}
1086
1087#[async_trait]
1088impl proto_persist_pub_sub_server::ProtoPersistPubSub for PersistGrpcPubSubServer {
1089    type PubSubStream = Pin<Box<dyn Stream<Item = Result<ProtoPubSubMessage, Status>> + Send>>;
1090
1091    #[mz_ore::instrument(name = "persist::rpc::server", level = "info")]
1092    async fn pub_sub(
1093        &self,
1094        request: Request<Streaming<ProtoPubSubMessage>>,
1095    ) -> Result<Response<Self::PubSubStream>, Status> {
1096        let caller_id = request
1097            .metadata()
1098            .get(AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY))
1099            .map(|key| key.to_str().ok())
1100            .flatten()
1101            .map(|key| key.to_string())
1102            .unwrap_or_else(|| "unknown".to_string());
1103        info!("Received Persist PubSub connection from: {:?}", caller_id);
1104
1105        let mut in_stream = request.into_inner();
1106        let (tx, rx) =
1107            tokio::sync::mpsc::channel(PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE.get(&self.cfg));
1108
1109        let caller = caller_id.clone();
1110        let cfg = Arc::clone(&self.cfg.configs);
1111        let server_state = Arc::clone(&self.state);
1112        // this spawn here to cleanup after connection error / disconnect, otherwise the stream
1113        // would not be polled after the connection drops. in our case, we want to clear the
1114        // connection and its subscriptions from our shared state when it drops.
1115        let connection_span = info_span!("connection", caller_id);
1116        mz_ore::task::spawn(
1117            || format!("persist_pubsub_connection({})", caller),
1118            async move {
1119                let connection = server_state.new_connection(tx);
1120                while let Some(result) = in_stream.next().await {
1121                    let req = match result {
1122                        Ok(req) => req,
1123                        Err(err) => {
1124                            warn!("pubsub connection err: {}", err);
1125                            break;
1126                        }
1127                    };
1128
1129                    match req.message {
1130                        None => {
1131                            warn!("received empty message from: {}", caller_id);
1132                        }
1133                        Some(proto_pub_sub_message::Message::PushDiff(req)) => {
1134                            let shard_id = req.shard_id.parse().expect("valid shard id");
1135                            let diff = VersionedData {
1136                                seqno: req.seqno.into_rust().expect("valid seqno"),
1137                                data: req.diff.clone(),
1138                            };
1139                            if PUBSUB_PUSH_DIFF_ENABLED.get(&cfg) {
1140                                connection.push_diff(&shard_id, &diff);
1141                            }
1142                        }
1143                        Some(proto_pub_sub_message::Message::Subscribe(diff)) => {
1144                            let shard_id = diff.shard_id.parse().expect("valid shard id");
1145                            connection.subscribe(&shard_id);
1146                        }
1147                        Some(proto_pub_sub_message::Message::Unsubscribe(diff)) => {
1148                            let shard_id = diff.shard_id.parse().expect("valid shard id");
1149                            connection.unsubscribe(&shard_id);
1150                        }
1151                    }
1152                }
1153
1154                info!("Persist PubSub connection ended: {:?}", caller_id);
1155            }
1156            .instrument(connection_span),
1157        );
1158
1159        let out_stream: Self::PubSubStream = Box::pin(ReceiverStream::new(rx));
1160        Ok(Response::new(out_stream))
1161    }
1162}
1163
/// An active connection managed by [PubSubState].
///
/// When dropped, removes itself from [PubSubState], clearing all of its subscriptions.
#[derive(Debug)]
pub(crate) struct PubSubConnection {
    /// Unique ID assigned by [PubSubState] when this connection was created.
    connection_id: usize,
    /// Channel over which broadcasts destined for this connection are sent;
    /// a clone is handed to [PubSubState] on each subscribe.
    notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    /// Handle to the shared server state this connection is registered with.
    state: Arc<PubSubState>,
}
1173
1174impl PubSubSenderInternal for PubSubConnection {
1175    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
1176        self.state.push_diff(self.connection_id, shard_id, diff)
1177    }
1178
1179    fn subscribe(&self, shard_id: &ShardId) {
1180        self.state
1181            .subscribe(self.connection_id, self.notifier.clone(), shard_id)
1182    }
1183
1184    fn unsubscribe(&self, shard_id: &ShardId) {
1185        self.state.unsubscribe(self.connection_id, shard_id)
1186    }
1187}
1188
1189impl Drop for PubSubConnection {
1190    fn drop(&mut self) {
1191        self.state.remove_connection(self.connection_id)
1192    }
1193}
1194
/// Unit tests for [PubSubState]: connection registration, subscription
/// bookkeeping, diff broadcast fan-out, and cleanup on drop.
#[cfg(test)]
mod pubsub_state {
    use std::str::FromStr;
    use std::sync::Arc;
    use std::sync::LazyLock;

    use bytes::Bytes;
    use mz_ore::collections::HashSet;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use tokio::sync::mpsc::Receiver;
    use tokio::sync::mpsc::error::TryRecvError;
    use tonic::Status;

    use crate::ShardId;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::rpc::{PubSubSenderInternal, PubSubState};

    // Fixed shard ids and diffs shared across the tests below.
    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());

    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };

    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    // Each operation asserts its connection id is registered; using an
    // unknown id (100) must panic.
    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_push_diff() {
        let state = Arc::new(PubSubState::new_for_test());
        state.push_diff(100, &SHARD_ID_0, &VERSIONED_DATA_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_subscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        let (tx, _) = tokio::sync::mpsc::channel(100);
        state.subscribe(100, tx, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_unsubscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        state.unsubscribe(100, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_remove() {
        let state = Arc::new(PubSubState::new_for_test());
        state.remove_connection(100)
    }

    // Exercises the full lifecycle of a single connection: registration,
    // no-echo on push, subscribe/unsubscribe idempotency, and drop cleanup.
    #[mz_ore::test]
    fn test_single_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx, mut rx) = tokio::sync::mpsc::channel(100);
        let connection = Arc::clone(&state).new_connection(tx);

        assert_eq!(
            state.active_connections(),
            HashSet::from([connection.connection_id])
        );

        // no messages should have been broadcasted yet
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        connection.push_diff(
            &SHARD_ID_0,
            &VersionedData {
                seqno: SeqNo::minimum(),
                data: Bytes::new(),
            },
        );

        // server should not broadcast a message back to originating client
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        // a connection can subscribe to a shard
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([SHARD_ID_0.clone()])
        );

        // a connection can unsubscribe
        connection.unsubscribe(&SHARD_ID_0);
        assert!(state.subscriptions(connection.connection_id).is_empty());

        // a connection can subscribe to many shards
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_1);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // and to a single shard many times idempotently
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // dropping the connection should unsubscribe all shards and unregister the connection
        let connection_id = connection.connection_id;
        drop(connection);
        assert!(state.subscriptions(connection_id).is_empty());
        assert!(state.active_connections().is_empty());
    }

    // Exercises broadcast fan-out across several connections, including
    // sender exclusion and cleanup when one connection drops.
    #[mz_ore::test]
    fn test_many_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx1, mut rx1) = tokio::sync::mpsc::channel(100);
        let conn1 = Arc::clone(&state).new_connection(tx1);

        let (tx2, mut rx2) = tokio::sync::mpsc::channel(100);
        let conn2 = Arc::clone(&state).new_connection(tx2);

        let (tx3, mut rx3) = tokio::sync::mpsc::channel(100);
        let conn3 = Arc::clone(&state).new_connection(tx3);

        conn1.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_1);

        assert_eq!(
            state.active_connections(),
            HashSet::from([
                conn1.connection_id,
                conn2.connection_id,
                conn3.connection_id
            ])
        );

        // broadcast a diff to a shard subscribed to by several connections
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx1, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff shared by publisher. it should not receive the diff back.
        conn1.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff to a shard subscribed to by one connection
        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert_push(&mut rx2, &SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff to a shard subscribed to by no connections
        conn2.unsubscribe(&SHARD_ID_1);
        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert!(matches!(rx2.try_recv(), Err(TryRecvError::Empty)));
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // dropping connections unsubscribes them
        let conn1_id = conn1.connection_id;
        drop(conn1);
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Disconnected)));
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        assert!(state.subscriptions(conn1_id).is_empty());
        assert_eq!(
            state.subscriptions(conn2.connection_id),
            HashSet::from([*SHARD_ID_0])
        );
        assert_eq!(state.subscriptions(conn3.connection_id), HashSet::new());
        assert_eq!(
            state.active_connections(),
            HashSet::from([conn2.connection_id, conn3.connection_id])
        );
    }

    /// Asserts that the next message in `rx` is a `PushDiff` carrying
    /// exactly `shard` and `data`.
    fn assert_push(
        rx: &mut Receiver<Result<ProtoPubSubMessage, Status>>,
        shard: &ShardId,
        data: &VersionedData,
    ) {
        let message = rx
            .try_recv()
            .expect("message in channel")
            .expect("pubsub")
            .message
            .expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }
}
1410
1411#[cfg(test)]
1412mod grpc {
1413    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
1414    use std::str::FromStr;
1415    use std::sync::Arc;
1416    use std::time::{Duration, Instant};
1417
1418    use bytes::Bytes;
1419    use futures_util::FutureExt;
1420    use mz_dyncfg::ConfigUpdates;
1421    use mz_ore::assert_none;
1422    use mz_ore::collections::HashMap;
1423    use mz_ore::metrics::MetricsRegistry;
1424    use mz_persist::location::{SeqNo, VersionedData};
1425    use mz_proto::RustType;
1426    use std::sync::LazyLock;
1427    use tokio::net::TcpListener;
1428    use tokio_stream::StreamExt;
1429    use tokio_stream::wrappers::TcpListenerStream;
1430
1431    use crate::ShardId;
1432    use crate::cfg::PersistConfig;
1433    use crate::internal::service::ProtoPubSubMessage;
1434    use crate::internal::service::proto_pub_sub_message::Message;
1435    use crate::metrics::Metrics;
1436    use crate::rpc::{
1437        GrpcPubSubClient, PUBSUB_CLIENT_ENABLED, PUBSUB_RECONNECT_BACKOFF, PersistGrpcPubSubServer,
1438        PersistPubSubClient, PersistPubSubClientConfig, PubSubState,
1439    };
1440
1441    static SHARD_ID_0: LazyLock<ShardId> =
1442        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
1443    static SHARD_ID_1: LazyLock<ShardId> =
1444        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
1445    const VERSIONED_DATA_0: VersionedData = VersionedData {
1446        seqno: SeqNo(0),
1447        data: Bytes::from_static(&[0, 1, 2, 3]),
1448    };
1449    const VERSIONED_DATA_1: VersionedData = VersionedData {
1450        seqno: SeqNo(1),
1451        data: Bytes::from_static(&[4, 5, 6, 7]),
1452    };
1453
1454    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
1455    const SUBSCRIPTIONS_TIMEOUT: Duration = Duration::from_secs(3);
1456    const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);
1457
1458    // NB: we use separate runtimes for client and server throughout these tests to cleanly drop
1459    // ALL tasks (including spawned child tasks) associated with one end of a connection, to most
1460    // closely model an actual disconnect.
1461
1462    #[mz_ore::test]
1463    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
1464    fn grpc_server() {
1465        let metrics = Arc::new(Metrics::new(
1466            &test_persist_config(),
1467            &MetricsRegistry::new(),
1468        ));
1469        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
1470        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
1471
1472        // start the server
1473        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
1474        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));
1475
1476        // start a client.
1477        {
1478            let _guard = client_runtime.enter();
1479            mz_ore::task::spawn(|| "client".to_string(), async move {
1480                let client = GrpcPubSubClient::connect(
1481                    PersistPubSubClientConfig {
1482                        url: format!("http://{}", addr),
1483                        caller_id: "client".to_string(),
1484                        persist_cfg: test_persist_config(),
1485                    },
1486                    metrics,
1487                );
1488                let _token = client.sender.subscribe(&SHARD_ID_0);
1489                tokio::time::sleep(Duration::MAX).await;
1490            });
1491        }
1492
1493        // wait until the client is connected and subscribed
1494        server_runtime.block_on(async {
1495            poll_until_true(CONNECT_TIMEOUT, || {
1496                server_state.active_connections().len() == 1
1497            })
1498            .await;
1499            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1500                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
1501            })
1502            .await
1503        });
1504
1505        // drop the client
1506        client_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);
1507
1508        // server should notice the client dropping and clean up its state
1509        server_runtime.block_on(async {
1510            poll_until_true(CONNECT_TIMEOUT, || {
1511                server_state.active_connections().is_empty()
1512            })
1513            .await;
1514            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1515                server_state.shard_subscription_counts() == HashMap::new()
1516            })
1517            .await
1518        });
1519    }
1520
    // Integration test: a client's sender transparently survives the PubSub
    // server going away and returning. Subscriptions made before a connection
    // exists — or while the server is down — are rehydrated on (re)connect,
    // but only for shards whose subscription token is still alive.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    fn grpc_client_sender_reconnects() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        // Separate runtimes so the server can be torn down (by shutting down
        // its runtime) without disturbing the client's background tasks.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // start a client
        let client = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // we can subscribe before connecting to the pubsub server
        let _token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        // we can subscribe and unsubscribe before connecting to the pubsub server
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        drop(_token_2);

        // create the server after the client is up
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            // client connects automatically once the server is up
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            // client rehydrated its subscriptions. notably, only includes the shard that
            // still has an active token
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await;
        });

        // kill the server
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // client can still send requests without error
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);

        // create a new server
        // NOTE(review): rebinding the exact same addr assumes the OS releases
        // the port promptly after shutdown; the `expect` below would surface a
        // failure to do so.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            // client automatically reconnects to new server
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            // and rehydrates its subscriptions, including the new one that was sent
            // while the server was unavailable.
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts()
                    == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
            })
            .await;
        });
    }
1601
1602    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1603    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
1604    async fn grpc_client_sender_subscription_tokens() {
1605        let metrics = Arc::new(Metrics::new(
1606            &test_persist_config(),
1607            &MetricsRegistry::new(),
1608        ));
1609
1610        let (addr, tcp_listener_stream) = new_tcp_listener().await;
1611        let server_state = spawn_server(tcp_listener_stream).await;
1612
1613        let client = GrpcPubSubClient::connect(
1614            PersistPubSubClientConfig {
1615                url: format!("http://{}", addr),
1616                caller_id: "client".to_string(),
1617                persist_cfg: test_persist_config(),
1618            },
1619            metrics,
1620        );
1621
1622        // our client connects
1623        poll_until_true(CONNECT_TIMEOUT, || {
1624            server_state.active_connections().len() == 1
1625        })
1626        .await;
1627
1628        // we can subscribe to a shard, receiving back a token
1629        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
1630        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1631            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
1632        })
1633        .await;
1634
1635        // dropping the token will unsubscribe our client
1636        drop(token);
1637        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1638            server_state.shard_subscription_counts() == HashMap::new()
1639        })
1640        .await;
1641
1642        // we can resubscribe to a shard
1643        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
1644        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1645            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
1646        })
1647        .await;
1648
1649        // we can subscribe many times idempotently, receiving back Arcs to the same token
1650        let token2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
1651        let token3 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
1652        assert_eq!(Arc::strong_count(&token), 3);
1653        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1654            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
1655        })
1656        .await;
1657
1658        // dropping all of the tokens will unsubscribe the shard
1659        drop(token);
1660        drop(token2);
1661        drop(token3);
1662        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1663            server_state.shard_subscription_counts() == HashMap::new()
1664        })
1665        .await;
1666
1667        // we can subscribe to many shards
1668        let _token0 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
1669        let _token1 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
1670        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1671            server_state.shard_subscription_counts()
1672                == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
1673        })
1674        .await;
1675    }
1676
1677    #[mz_ore::test]
1678    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
1679    fn grpc_client_receiver() {
1680        let metrics = Arc::new(Metrics::new(
1681            &PersistConfig::new_for_tests(),
1682            &MetricsRegistry::new(),
1683        ));
1684        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
1685        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
1686        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
1687
1688        // create two clients, so we can test that broadcast messages are received by the other
1689        let mut client_1 = client_runtime.block_on(async {
1690            GrpcPubSubClient::connect(
1691                PersistPubSubClientConfig {
1692                    url: format!("http://{}", addr),
1693                    caller_id: "client_1".to_string(),
1694                    persist_cfg: test_persist_config(),
1695                },
1696                Arc::clone(&metrics),
1697            )
1698        });
1699        let mut client_2 = client_runtime.block_on(async {
1700            GrpcPubSubClient::connect(
1701                PersistPubSubClientConfig {
1702                    url: format!("http://{}", addr),
1703                    caller_id: "client_2".to_string(),
1704                    persist_cfg: test_persist_config(),
1705                },
1706                metrics,
1707            )
1708        });
1709
1710        // we can check our receiver output before connecting to the server.
1711        // these calls are race-y, since there's no guarantee on the time it
1712        // would take for a message to be received were one to have been sent,
1713        // but, better than nothing?
1714        assert_none!(client_1.receiver.next().now_or_never());
1715        assert_none!(client_2.receiver.next().now_or_never());
1716
1717        // start the server
1718        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));
1719
1720        // wait until both clients are connected
1721        server_runtime.block_on(poll_until_true(CONNECT_TIMEOUT, || {
1722            server_state.active_connections().len() == 2
1723        }));
1724
1725        // no messages have been broadcast yet
1726        assert_none!(client_1.receiver.next().now_or_never());
1727        assert_none!(client_2.receiver.next().now_or_never());
1728
1729        // subscribe and send a diff
1730        let _token_client_1 = Arc::clone(&client_1.sender).subscribe(&SHARD_ID_0);
1731        let _token_client_2 = Arc::clone(&client_2.sender).subscribe(&SHARD_ID_0);
1732        server_runtime.block_on(poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1733            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
1734        }));
1735
1736        // the subscriber non-sender client receives the diff
1737        client_1.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_1);
1738        assert_none!(client_1.receiver.next().now_or_never());
1739        client_runtime.block_on(async {
1740            assert_push(
1741                client_2.receiver.next().await.expect("has diff"),
1742                &SHARD_ID_0,
1743                &VERSIONED_DATA_1,
1744            )
1745        });
1746
1747        // kill the server
1748        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);
1749
1750        // receivers can still be polled without error
1751        assert_none!(client_1.receiver.next().now_or_never());
1752        assert_none!(client_2.receiver.next().now_or_never());
1753
1754        // create a new server
1755        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
1756        let tcp_listener_stream = server_runtime.block_on(async {
1757            TcpListenerStream::new(
1758                TcpListener::bind(addr)
1759                    .await
1760                    .expect("can bind to previous addr"),
1761            )
1762        });
1763        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));
1764
1765        // client automatically reconnects to new server and rehydrates subscriptions
1766        server_runtime.block_on(async {
1767            poll_until_true(CONNECT_TIMEOUT, || {
1768                server_state.active_connections().len() == 2
1769            })
1770            .await;
1771            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1772                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
1773            })
1774            .await;
1775        });
1776
1777        // pushing and receiving diffs works as expected.
1778        // this time we'll push from the other client.
1779        client_2.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1780        client_runtime.block_on(async {
1781            assert_push(
1782                client_1.receiver.next().await.expect("has diff"),
1783                &SHARD_ID_0,
1784                &VERSIONED_DATA_0,
1785            )
1786        });
1787        assert_none!(client_2.receiver.next().now_or_never());
1788    }
1789
1790    async fn new_tcp_listener() -> (SocketAddr, TcpListenerStream) {
1791        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
1792        let tcp_listener = TcpListener::bind(addr).await.expect("tcp listener");
1793
1794        (
1795            tcp_listener.local_addr().expect("bound to local address"),
1796            TcpListenerStream::new(tcp_listener),
1797        )
1798    }
1799
1800    #[allow(clippy::unused_async)]
1801    async fn spawn_server(tcp_listener_stream: TcpListenerStream) -> Arc<PubSubState> {
1802        let server = PersistGrpcPubSubServer::new(&test_persist_config(), &MetricsRegistry::new());
1803        let server_state = Arc::clone(&server.state);
1804
1805        let _server_task = mz_ore::task::spawn(|| "server".to_string(), async move {
1806            server.serve_with_stream(tcp_listener_stream).await
1807        });
1808        server_state
1809    }
1810
1811    async fn poll_until_true<F>(timeout: Duration, f: F)
1812    where
1813        F: Fn() -> bool,
1814    {
1815        let now = Instant::now();
1816        loop {
1817            if f() {
1818                return;
1819            }
1820
1821            if now.elapsed() > timeout {
1822                panic!("timed out");
1823            }
1824
1825            tokio::time::sleep(Duration::from_millis(1)).await;
1826        }
1827    }
1828
1829    fn assert_push(message: ProtoPubSubMessage, shard: &ShardId, data: &VersionedData) {
1830        let message = message.message.expect("proto contains message");
1831        match message {
1832            Message::PushDiff(x) => {
1833                assert_eq!(x.shard_id, shard.into_proto());
1834                assert_eq!(x.seqno, data.seqno.into_proto());
1835                assert_eq!(x.diff, data.data);
1836            }
1837            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
1838        };
1839    }
1840
1841    fn test_persist_config() -> PersistConfig {
1842        let cfg = PersistConfig::new_for_tests();
1843
1844        let mut updates = ConfigUpdates::default();
1845        updates.add(&PUBSUB_CLIENT_ENABLED, true);
1846        updates.add(&PUBSUB_RECONNECT_BACKOFF, Duration::ZERO);
1847        cfg.apply_from(&updates);
1848
1849        cfg
1850    }
1851}