mz_persist_client/rpc.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! gRPC-based implementations of Persist PubSub client and server.

use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::time::{Duration, Instant, SystemTime};

use anyhow::{Error, anyhow};
use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use futures_util::StreamExt;
use mz_dyncfg::Config;
use mz_ore::cast::CastFrom;
use mz_ore::collections::{HashMap, HashSet};
use mz_ore::metrics::MetricsRegistry;
use mz_ore::retry::RetryResult;
use mz_ore::task::JoinHandle;
use mz_persist::location::VersionedData;
use mz_proto::{ProtoType, RustType};
use prost::Message;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::error::TrySendError;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap};
use tonic::transport::Endpoint;
use tonic::{Extensions, Request, Response, Status, Streaming};
use tracing::{Instrument, debug, error, info, info_span, warn};

use crate::ShardId;
use crate::cache::{DynState, StateCache};
use crate::cfg::PersistConfig;
use crate::internal::metrics::{PubSubClientCallMetrics, PubSubServerMetrics};
use crate::internal::service::proto_persist_pub_sub_client::ProtoPersistPubSubClient;
use crate::internal::service::proto_persist_pub_sub_server::ProtoPersistPubSubServer;
use crate::internal::service::{
    ProtoPubSubMessage, ProtoPushDiff, ProtoSubscribe, ProtoUnsubscribe,
    proto_persist_pub_sub_server, proto_pub_sub_message,
};
use crate::metrics::Metrics;

/// Determines whether PubSub clients should connect to the PubSub server.
pub(crate) const PUBSUB_CLIENT_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_client_enabled",
    true,
    "Whether to connect to the Persist PubSub service.",
);

/// For connected clients, determines whether to push state diffs to the PubSub
/// server. For the server, determines whether to broadcast state diffs to
/// subscribed clients.
pub(crate) const PUBSUB_PUSH_DIFF_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_push_diff_enabled",
    true,
    "Whether to push state diffs to Persist PubSub.",
);

/// For same-process connections, determines whether subscribe calls should be
/// delegated to the PubSub server, or whether to rely solely on the state
/// cache shared by the in-process client and server.
pub(crate) const PUBSUB_SAME_PROCESS_DELEGATE_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_same_process_delegate_enabled",
    true,
    "Whether to push state diffs to Persist PubSub on the same process.",
);

/// Timeout per connection attempt to Persist PubSub service.
pub(crate) const PUBSUB_CONNECT_ATTEMPT_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_connect_attempt_timeout",
    Duration::from_secs(5),
    "Timeout per connection attempt to Persist PubSub service.",
);

/// Timeout per request attempt to Persist PubSub service.
pub(crate) const PUBSUB_REQUEST_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_request_timeout",
    Duration::from_secs(5),
    "Timeout per request attempt to Persist PubSub service.",
);

/// Maximum backoff when retrying connection establishment to Persist PubSub service.
pub(crate) const PUBSUB_CONNECT_MAX_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_connect_max_backoff",
    Duration::from_secs(60),
    "Maximum backoff when retrying connection establishment to Persist PubSub service.",
);

/// Size of channel used to buffer send messages to PubSub service.
pub(crate) const PUBSUB_CLIENT_SENDER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_sender_channel_size",
    25,
    "Size of channel used to buffer send messages to PubSub service.",
);

/// Size of channel used to buffer received messages from PubSub service.
pub(crate) const PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_receiver_channel_size",
    25,
    "Size of channel used to buffer received messages from PubSub service.",
);

/// Size of channel used per connection to buffer broadcasted messages from PubSub server.
pub(crate) const PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_server_connection_channel_size",
    25,
    "Size of channel used per connection to buffer broadcasted messages from PubSub server.",
);

/// Size of channel used by the state cache to broadcast shard state references.
pub(crate) const PUBSUB_STATE_CACHE_SHARD_REF_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_state_cache_shard_ref_channel_size",
    25,
    "Size of channel used by the state cache to broadcast shard state references.",
);

/// Backoff after an established connection to Persist PubSub service fails.
pub(crate) const PUBSUB_RECONNECT_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_reconnect_backoff",
    Duration::from_secs(5),
    "Backoff after an established connection to Persist PubSub service fails.",
);

/// Top-level trait for creating a PubSub client.
///
/// Returns a [PubSubClientConnection] with a [PubSubSender] for issuing RPCs to the PubSub
/// server, and a [PubSubReceiver] that receives messages, such as state diffs.
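///
/// A minimal connection sketch (illustrative, not compiled; assumes `config`
/// and `metrics` are constructed elsewhere, and uses the gRPC implementation
/// below as an example):
///
/// ```ignore
/// // `connect` returns the matched sender/receiver pair for this client.
/// let PubSubClientConnection { sender, receiver } =
///     GrpcPubSubClient::connect(config, metrics);
/// ```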
pub trait PersistPubSubClient {
    /// Receive handles with which to push and subscribe to diffs.
    fn connect(
        pubsub_config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) -> PubSubClientConnection;
}

/// Wrapper type for a matching [PubSubSender] and [PubSubReceiver] client pair.
#[derive(Debug)]
pub struct PubSubClientConnection {
    /// The sender client to Persist PubSub.
    pub sender: Arc<dyn PubSubSender>,
    /// The receiver client to Persist PubSub.
    pub receiver: Box<dyn PubSubReceiver>,
}

impl PubSubClientConnection {
    /// Creates a new [PubSubClientConnection] from a matching [PubSubSender] and [PubSubReceiver].
    pub fn new(sender: Arc<dyn PubSubSender>, receiver: Box<dyn PubSubReceiver>) -> Self {
        Self { sender, receiver }
    }

    /// Creates a no-op [PubSubClientConnection] that neither sends nor receives messages.
    pub fn noop() -> Self {
        Self {
            sender: Arc::new(NoopPubSubSender),
            receiver: Box::new(futures::stream::empty()),
        }
    }
}

/// The public send-side client to Persist PubSub.
pub trait PubSubSender: std::fmt::Debug + Send + Sync {
    /// Push a diff to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribe the corresponding [PubSubReceiver] to diffs for the given shard.
    ///
    /// Returns a token that, when dropped, will unsubscribe the client from the
    /// shard.
    ///
    /// If the client is already subscribed to the shard, repeated calls will make
    /// no further calls to the server and instead return clones of the `Arc<ShardSubscriptionToken>`.
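    ///
    /// A sketch of the intended token lifecycle (illustrative, not compiled;
    /// `sender` is assumed to be an `Arc<dyn PubSubSender>`):
    ///
    /// ```ignore
    /// let token = Arc::clone(&sender).subscribe(&shard_id);
    /// // ... diffs for `shard_id` now arrive at the paired PubSubReceiver ...
    /// drop(token); // unsubscribes this client from the shard
    /// ```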
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken>;
}

/// The internal send-side client trait to Persist PubSub, responsible for issuing RPCs
/// to the PubSub service. This trait is separated out from [PubSubSender] to keep the
/// client implementations straightforward, while offering a more ergonomic public API
/// in [PubSubSender].
trait PubSubSenderInternal: Debug + Send + Sync {
    /// Push a diff to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribe the corresponding [PubSubReceiver] to diffs for the given shard.
    ///
    /// This call is idempotent and is a no-op for an already subscribed shard.
    fn subscribe(&self, shard_id: &ShardId);

    /// Unsubscribe the corresponding [PubSubReceiver] from diffs for the given shard.
    ///
    /// This call is idempotent and is a no-op for already unsubscribed shards.
    fn unsubscribe(&self, shard_id: &ShardId);
}

/// The receive-side client to Persist PubSub.
///
/// Returns diffs (and maybe in the future, blobs) for any shards subscribed to
/// by the corresponding `PubSubSender`.
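///
/// A consumption sketch (illustrative, not compiled; assumes an async context
/// and a `receiver` taken from a [PubSubClientConnection]):
///
/// ```ignore
/// while let Some(message) = receiver.next().await {
///     // e.g. apply a pushed state diff to a local state cache
/// }
/// ```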
pub trait PubSubReceiver:
    Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}

impl<T> PubSubReceiver for T where
    T: Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}

/// A token corresponding to a subscription to diffs for a particular shard.
///
/// When dropped, the client that originated the token will be unsubscribed
/// from further diffs to the shard.
pub struct ShardSubscriptionToken {
    pub(crate) shard_id: ShardId,
    sender: Arc<dyn PubSubSenderInternal>,
}

impl Debug for ShardSubscriptionToken {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let ShardSubscriptionToken {
            shard_id,
            sender: _sender,
        } = self;
        write!(f, "ShardSubscriptionToken({})", shard_id)
    }
}

impl Drop for ShardSubscriptionToken {
    fn drop(&mut self) {
        self.sender.unsubscribe(&self.shard_id);
    }
}

/// A gRPC metadata key to indicate the caller id of a client.
pub const PERSIST_PUBSUB_CALLER_KEY: &str = "persist-pubsub-caller-id";

/// Client configuration for connecting to a remote PubSub server.
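///
/// For example (illustrative values; `persist_cfg` is assumed to be an
/// existing [PersistConfig], and the caller id is arbitrary):
///
/// ```ignore
/// let config = PersistPubSubClientConfig {
///     url: "http://localhost:6879".to_string(),
///     caller_id: "clusterd-0".to_string(),
///     persist_cfg,
/// };
/// ```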
#[derive(Debug)]
pub struct PersistPubSubClientConfig {
    /// Connection address for the pubsub server, e.g. `http://localhost:6879`
    pub url: String,
    /// A caller ID for the client. Used for debugging.
    pub caller_id: String,
    /// A copy of [PersistConfig]
    pub persist_cfg: PersistConfig,
}

/// A [PersistPubSubClient] implementation backed by gRPC.
///
/// Returns a [PubSubClientConnection] backed by channels that submit and receive
/// messages to and from a long-lived bidirectional gRPC stream. The gRPC stream
/// will be transparently reestablished if the connection is lost.
#[derive(Debug)]
pub struct GrpcPubSubClient;

impl GrpcPubSubClient {
    async fn reconnect_to_server_forever(
        send_requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
        receiver_input: &tokio::sync::mpsc::Sender<ProtoPubSubMessage>,
        sender: Arc<SubscriptionTrackingSender>,
        metadata: MetadataMap,
        config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) {
        // Once enabled, the PubSub server cannot be disabled or otherwise
        // reconfigured. So we wait for at least one configuration sync to
        // complete. This gives `environmentd` at least one chance to update
        // PubSub configuration parameters. See database-issues#7168 for details.
        config.persist_cfg.configs_synced_once().await;

        let mut is_first_connection_attempt = true;
        loop {
            let sender = Arc::clone(&sender);
            metrics.pubsub_client.grpc_connection.connected.set(0);

            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                tokio::time::sleep(Duration::from_secs(5)).await;
                continue;
            }

            // add a bit of backoff when reconnecting after some network/server failure
            if is_first_connection_attempt {
                is_first_connection_attempt = false;
            } else {
                tokio::time::sleep(PUBSUB_RECONNECT_BACKOFF.get(&config.persist_cfg)).await;
            }

            info!("Connecting to Persist PubSub: {}", config.url);
            let client = mz_ore::retry::Retry::default()
                .clamp_backoff(PUBSUB_CONNECT_MAX_BACKOFF.get(&config.persist_cfg))
                .retry_async(|_| async {
                    metrics
                        .pubsub_client
                        .grpc_connection
                        .connect_call_attempt_count
                        .inc();
                    let endpoint = match Endpoint::from_str(&config.url) {
                        Ok(endpoint) => endpoint,
                        Err(err) => return RetryResult::FatalErr(err),
                    };
                    ProtoPersistPubSubClient::connect(
                        endpoint
                            .connect_timeout(
                                PUBSUB_CONNECT_ATTEMPT_TIMEOUT.get(&config.persist_cfg),
                            )
                            .timeout(PUBSUB_REQUEST_TIMEOUT.get(&config.persist_cfg)),
                    )
                    .await
                    .into()
                })
                .await;

            let mut client = match client {
                Ok(client) => client,
                Err(err) => {
                    error!("fatal error connecting to persist pubsub: {:?}", err);
                    return;
                }
            };

            metrics
                .pubsub_client
                .grpc_connection
                .connection_established_count
                .inc();
            metrics.pubsub_client.grpc_connection.connected.set(1);

            info!("Connected to Persist PubSub: {}", config.url);

            let mut broadcast = BroadcastStream::new(send_requests.subscribe());
            let broadcast_errors = metrics
                .pubsub_client
                .grpc_connection
                .broadcast_recv_lagged_count
                .clone();

            // shard subscriptions are tracked by connection on the server, so if our
            // gRPC stream is ever swapped out, we must inform the server which shards
            // our client intended to be subscribed to.
            let broadcast_messages = async_stream::stream! {
                'reconnect: loop {
                    // If we have active subscriptions, resend them.
                    for id in sender.subscriptions() {
                        debug!("re-subscribing to shard: {id}");
                        yield create_request(proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
                            shard_id: id.into_proto(),
                        }));
                    }

                    // Forward on messages from the broadcast channel, reconnecting if necessary.
                    while let Some(message) = broadcast.next().await {
                        debug!("sending pubsub message: {:?}", message);
                        match message {
                            Ok(message) => yield message,
                            Err(BroadcastStreamRecvError::Lagged(i)) => {
                                broadcast_errors.inc_by(i);
                                continue 'reconnect;
                            }
                        }
                    }
                    debug!("exhausted pubsub broadcast stream; shutting down");
                    break;
                }
            };
            let pubsub_request =
                Request::from_parts(metadata.clone(), Extensions::default(), broadcast_messages);

            let responses = match client.pub_sub(pubsub_request).await {
                Ok(response) => response.into_inner(),
                Err(err) => {
                    warn!("pub_sub rpc error: {:?}", err);
                    continue;
                }
            };

            let stream_completed = GrpcPubSubClient::consume_grpc_stream(
                responses,
                receiver_input,
                &config,
                metrics.as_ref(),
            )
            .await;

            match stream_completed {
                // common case: reconnect due to some transient error
                Ok(_) => continue,
                // uncommon case: we should stop connecting to the PubSub server entirely.
                // in practice, we should only see this during shut down.
                Err(err) => {
                    warn!("shutting down connection loop to Persist PubSub: {}", err);
                    return;
                }
            }
        }
    }

    async fn consume_grpc_stream(
        mut responses: Streaming<ProtoPubSubMessage>,
        receiver_input: &Sender<ProtoPubSubMessage>,
        config: &PersistPubSubClientConfig,
        metrics: &Metrics,
    ) -> Result<(), Error> {
        loop {
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                return Ok(());
            }

            debug!("awaiting next pubsub response");
            match responses.next().await {
                Some(Ok(message)) => {
                    debug!("received pubsub message: {:?}", message);
                    match receiver_input.send(message).await {
                        Ok(_) => {}
                        // if the receiver has dropped, we can drop our
                        // no-longer-needed grpc connection entirely.
                        Err(err) => {
                            return Err(anyhow!("closing pubsub grpc client connection: {}", err));
                        }
                    }
                }
                Some(Err(err)) => {
                    metrics.pubsub_client.grpc_connection.grpc_error_count.inc();
                    warn!("pubsub client error: {:?}", err);
                    return Ok(());
                }
                None => return Ok(()),
            }
        }
    }
}

impl PersistPubSubClient for GrpcPubSubClient {
    fn connect(config: PersistPubSubClientConfig, metrics: Arc<Metrics>) -> PubSubClientConnection {
        // Create a stable channel for our client to transmit messages into our gRPC stream. We use a
        // broadcast to allow us to create new Receivers on demand, in case the underlying gRPC stream
        // is swapped out (e.g. due to connection failure). It is expected that only 1 Receiver is
        // ever active at a given time.
        let (send_requests, _) = tokio::sync::broadcast::channel(
            PUBSUB_CLIENT_SENDER_CHANNEL_SIZE.get(&config.persist_cfg),
        );
        // Create a stable channel to receive messages from our gRPC stream. The input end lives inside
        // a task that continuously reads from the active gRPC stream, decoupling the `PubSubReceiver`
        // from the lifetime of a specific gRPC connection.
        let (receiver_input, receiver_output) = tokio::sync::mpsc::channel(
            PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&config.persist_cfg),
        );

        let sender = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            GrpcPubSubSender {
                metrics: Arc::clone(&metrics),
                requests: send_requests.clone(),
            },
        )));
        let pubsub_sender = Arc::clone(&sender);
        mz_ore::task::spawn(
            || "persist::rpc::client::connection".to_string(),
            async move {
                let mut metadata = MetadataMap::new();
                metadata.insert(
                    AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY),
                    AsciiMetadataValue::try_from(&config.caller_id)
                        .unwrap_or_else(|_| AsciiMetadataValue::from_static("unknown")),
                );

                GrpcPubSubClient::reconnect_to_server_forever(
                    send_requests,
                    &receiver_input,
                    pubsub_sender,
                    metadata,
                    config,
                    metrics,
                )
                .await;
            },
        );

        PubSubClientConnection {
            sender,
            receiver: Box::new(ReceiverStream::new(receiver_output)),
        }
    }
}

/// An internal, gRPC-backed implementation of [PubSubSenderInternal].
struct GrpcPubSubSender {
    metrics: Arc<Metrics>,
    requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
}

impl Debug for GrpcPubSubSender {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let GrpcPubSubSender {
            metrics: _metrics,
            requests: _requests,
        } = self;

        write!(f, "GrpcPubSubSender")
    }
}

fn create_request(message: proto_pub_sub_message::Message) -> ProtoPubSubMessage {
    let now = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("failed to get millis since epoch");

    ProtoPubSubMessage {
        timestamp: Some(now.into_proto()),
        message: Some(message),
    }
}

impl GrpcPubSubSender {
    fn send(&self, message: proto_pub_sub_message::Message, metrics: &PubSubClientCallMetrics) {
        let size = message.encoded_len();

        match self.requests.send(create_request(message)) {
            Ok(_) => {
                metrics.succeeded.inc();
                metrics.bytes_sent.inc_by(u64::cast_from(size));
            }
            Err(err) => {
                metrics.failed.inc();
                debug!("error sending client message: {}", err);
            }
        }
    }
}

impl PubSubSenderInternal for GrpcPubSubSender {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.send(
            proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                shard_id: shard_id.into_proto(),
                seqno: diff.seqno.into_proto(),
                diff: diff.data.clone(),
            }),
            &self.metrics.pubsub_client.sender.push,
        )
    }

    fn subscribe(&self, shard_id: &ShardId) {
        self.send(
            proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
                shard_id: shard_id.into_proto(),
            }),
            &self.metrics.pubsub_client.sender.subscribe,
        )
    }

    fn unsubscribe(&self, shard_id: &ShardId) {
        self.send(
            proto_pub_sub_message::Message::Unsubscribe(ProtoUnsubscribe {
                shard_id: shard_id.into_proto(),
            }),
            &self.metrics.pubsub_client.sender.unsubscribe,
        )
    }
}

/// A wrapper for a [PubSubSenderInternal] that implements [PubSubSender]
/// by maintaining a map of active shard subscriptions to their tokens.
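///
/// A sketch of the idempotent subscribe behavior this enables (illustrative,
/// not compiled; `sender` is assumed to be an `Arc<SubscriptionTrackingSender>`):
///
/// ```ignore
/// let t1 = Arc::clone(&sender).subscribe(&shard_id);
/// let t2 = Arc::clone(&sender).subscribe(&shard_id); // no second RPC; clones `t1`
/// assert!(Arc::ptr_eq(&t1, &t2));
/// drop(t1);
/// drop(t2); // last clone dropped: a single Unsubscribe RPC is issued
/// ```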
#[derive(Debug)]
struct SubscriptionTrackingSender {
    delegate: Arc<dyn PubSubSenderInternal>,
    subscribes: Arc<Mutex<BTreeMap<ShardId, Weak<ShardSubscriptionToken>>>>,
}

impl SubscriptionTrackingSender {
    fn new(sender: Arc<dyn PubSubSenderInternal>) -> Self {
        Self {
            delegate: sender,
            subscribes: Default::default(),
        }
    }

    fn subscriptions(&self) -> Vec<ShardId> {
        let mut subscribes = self.subscribes.lock().expect("lock");
        let mut out = Vec::with_capacity(subscribes.len());
        subscribes.retain(|shard_id, token| {
            if token.upgrade().is_none() {
                false
            } else {
                debug!("reconnecting to: {}", shard_id);
                out.push(*shard_id);
                true
            }
        });
        out
    }
}

impl PubSubSender for SubscriptionTrackingSender {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff)
    }

    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        let mut subscribes = self.subscribes.lock().expect("lock");
        if let Some(token) = subscribes.get(shard_id) {
            match token.upgrade() {
                None => assert!(subscribes.remove(shard_id).is_some()),
                Some(token) => {
                    return Arc::clone(&token);
                }
            }
        }

        let pubsub_sender = Arc::clone(&self.delegate);
        let token = Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: pubsub_sender,
        });

        assert!(
            subscribes
                .insert(*shard_id, Arc::downgrade(&token))
                .is_none()
        );

        self.delegate.subscribe(shard_id);

        token
    }
}

/// A wrapper intended to provide client-side metrics for a connection
/// that communicates directly with the server state, such as one created
/// by [PersistGrpcPubSubServer::new_same_process_connection].
#[derive(Debug)]
pub struct MetricsSameProcessPubSubSender {
    delegate_subscribe: bool,
    metrics: Arc<Metrics>,
    delegate: Arc<dyn PubSubSender>,
}

impl MetricsSameProcessPubSubSender {
    /// Returns a new [MetricsSameProcessPubSubSender], wrapping the given
    /// `Arc<dyn PubSubSender>`'s calls to provide client-side metrics.
    pub fn new(
        cfg: &PersistConfig,
        pubsub_sender: Arc<dyn PubSubSender>,
        metrics: Arc<Metrics>,
    ) -> Self {
        Self {
            delegate_subscribe: PUBSUB_SAME_PROCESS_DELEGATE_ENABLED.get(cfg),
            delegate: pubsub_sender,
            metrics,
        }
    }
}

impl PubSubSender for MetricsSameProcessPubSubSender {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff);
        self.metrics.pubsub_client.sender.push.succeeded.inc();
    }

    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        if self.delegate_subscribe {
            let delegate = Arc::clone(&self.delegate);
            delegate.subscribe(shard_id)
        } else {
            // Create a no-op token that neither subscribes nor unsubscribes.
            // This is ideal for single-process persist setups, since the sender and
            // receiver should already share a state cache... but if the diffs are
            // generated remotely but applied on the server, this may cause us to fall
            // back to polling consensus.
            Arc::new(ShardSubscriptionToken {
                shard_id: *shard_id,
                sender: Arc::new(NoopPubSubSender),
            })
        }
    }
}

#[derive(Debug)]
pub(crate) struct NoopPubSubSender;

impl PubSubSenderInternal for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
    fn subscribe(&self, _shard_id: &ShardId) {}
    fn unsubscribe(&self, _shard_id: &ShardId) {}
}

impl PubSubSender for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}

    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: self,
        })
    }
}

/// Spawns a Tokio task that consumes a [PubSubReceiver], applying its diffs to a [StateCache].
pub(crate) fn subscribe_state_cache_to_pubsub(
    cache: Arc<StateCache>,
    mut pubsub_receiver: Box<dyn PubSubReceiver>,
) -> JoinHandle<()> {
    let mut state_refs: HashMap<ShardId, Weak<dyn DynState>> = HashMap::new();
    let receiver_metrics = cache.metrics.pubsub_client.receiver.clone();

    mz_ore::task::spawn(
        || "persist::rpc::client::state_cache_diff_apply",
        async move {
            while let Some(msg) = pubsub_receiver.next().await {
                match msg.message {
                    Some(proto_pub_sub_message::Message::PushDiff(diff)) => {
                        receiver_metrics.push_received.inc();
                        let shard_id = diff.shard_id.into_rust().expect("valid shard id");
                        let diff = VersionedData {
                            seqno: diff.seqno.into_rust().expect("valid SeqNo"),
                            data: diff.diff,
                        };
                        debug!(
                            "applying pubsub diff {} {} {}",
                            shard_id,
                            diff.seqno,
                            diff.data.len()
                        );

                        let mut pushed_diff = false;
                        if let Some(state_ref) = state_refs.get(&shard_id) {
                            // common case: we have a reference to the shard state already
                            // and can apply our diff directly.
                            if let Some(state) = state_ref.upgrade() {
                                state.push_diff(diff.clone());
                                pushed_diff = true;
                                receiver_metrics.state_pushed_diff_fast_path.inc();
                            }
                        }

                        if !pushed_diff {
                            // uncommon case: we either don't have a reference yet, or ours
                            // is out-of-date (e.g. the shard was dropped and then re-added
                            // to StateCache). here we'll fetch the latest, try to apply the
                            // diff again, and update our local reference.
                            let state_ref = cache.get_state_weak(&shard_id);
                            match state_ref {
                                None => {
                                    state_refs.remove(&shard_id);
                                }
                                Some(state_ref) => {
                                    if let Some(state) = state_ref.upgrade() {
                                        state.push_diff(diff);
                                        pushed_diff = true;
                                        state_refs.insert(shard_id, state_ref);
                                    } else {
                                        state_refs.remove(&shard_id);
                                    }
                                }
                            }

                            if pushed_diff {
                                receiver_metrics.state_pushed_diff_slow_path_succeeded.inc();
                            } else {
                                receiver_metrics.state_pushed_diff_slow_path_failed.inc();
                            }
                        }

                        if let Some(send_timestamp) = msg.timestamp {
                            let send_timestamp =
                                send_timestamp.into_rust().expect("valid timestamp");
                            let now = SystemTime::now()
                                .duration_since(SystemTime::UNIX_EPOCH)
                                .expect("failed to get millis since epoch");
                            receiver_metrics
                                .approx_diff_latency_seconds
                                .observe((now.saturating_sub(send_timestamp)).as_secs_f64());
                        }
                    }
                    ref msg @ None | ref msg @ Some(_) => {
                        warn!("pubsub client received unexpected message: {:?}", msg);
                        receiver_metrics.unknown_message_received.inc();
                    }
                }
            }
        },
    )
}

/// Internal state of a PubSub server implementation.
#[derive(Debug)]
pub(crate) struct PubSubState {
    /// Assigns a unique ID to each incoming connection.
    connection_id_counter: AtomicUsize,
    /// Maintains a mapping of `ShardId --> [ConnectionId -> Tx]`.
    shard_subscribers:
        Arc<RwLock<BTreeMap<ShardId, BTreeMap<usize, Sender<Result<ProtoPubSubMessage, Status>>>>>>,
    /// Active connections.
    connections: Arc<RwLock<HashSet<usize>>>,
    /// Server-side metrics.
    metrics: Arc<PubSubServerMetrics>,
}

impl PubSubState {
    fn new_connection(
        self: Arc<Self>,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    ) -> PubSubConnection {
        let connection_id = self.connection_id_counter.fetch_add(1, Ordering::SeqCst);
        {
            debug!("inserting connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(connections.insert(connection_id));
        }

        self.metrics.active_connections.inc();
        PubSubConnection {
            connection_id,
            notifier,
            state: self,
        }
    }

    fn remove_connection(&self, connection_id: usize) {
        let now = Instant::now();

        {
            debug!("removing connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(
                connections.remove(&connection_id),
                "unknown connection id: {}",
                connection_id
            );
        }

        {
            let mut subscribers = self.shard_subscribers.write().expect("lock poisoned");
            subscribers.retain(|_shard, connections_for_shard| {
                connections_for_shard.remove(&connection_id);
                !connections_for_shard.is_empty()
            });
        }

        self.metrics
            .connection_cleanup_seconds
            .inc_by(now.elapsed().as_secs_f64());
        self.metrics.active_connections.dec();
    }

    fn push_diff(&self, connection_id: usize, shard_id: &ShardId, data: &VersionedData) {
        let now = Instant::now();
        self.metrics.push_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        let subscribers = self.shard_subscribers.read().expect("lock poisoned");
        if let Some(subscribed_connections) = subscribers.get(shard_id) {
            let mut num_sent = 0;
            let mut data_size = 0;

            for (subscribed_conn_id, tx) in subscribed_connections {
                // skip sending the diff back to the original sender
                if *subscribed_conn_id == connection_id {
                    continue;
                }
                debug!(
                    "server forwarding req to conn {}: {} {} {}",
                    subscribed_conn_id,
                    &shard_id,
                    data.seqno,
                    data.data.len()
                );
                let req = create_request(proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                    seqno: data.seqno.into_proto(),
                    shard_id: shard_id.to_string(),
                    diff: Bytes::clone(&data.data),
                }));
                data_size = req.encoded_len();
                match tx.try_send(Ok(req)) {
                    Ok(_) => {
                        num_sent += 1;
                    }
                    Err(TrySendError::Full(_)) => {
                        self.metrics.broadcasted_diff_dropped_channel_full.inc();
                    }
                    Err(TrySendError::Closed(_)) => {}
                };
            }

            self.metrics.broadcasted_diff_count.inc_by(num_sent);
            self.metrics
                .broadcasted_diff_bytes
                .inc_by(num_sent * u64::cast_from(data_size));
        }

        self.metrics
            .push_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    fn subscribe(
        &self,
        connection_id: usize,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
        shard_id: &ShardId,
    ) {
        let now = Instant::now();
        self.metrics.subscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            subscribed_shards
                .entry(*shard_id)
                .or_default()
                .insert(connection_id, notifier);
        }

        self.metrics
            .subscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    fn unsubscribe(&self, connection_id: usize, shard_id: &ShardId) {
        let now = Instant::now();
        self.metrics.unsubscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            if let Entry::Occupied(mut entry) = subscribed_shards.entry(*shard_id) {
                let subscribed_connections = entry.get_mut();
                subscribed_connections.remove(&connection_id);

                if subscribed_connections.is_empty() {
                    entry.remove_entry();
                }
            }
        }

        self.metrics
            .unsubscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    #[cfg(test)]
    fn new_for_test() -> Self {
        Self {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(PubSubServerMetrics::new(&MetricsRegistry::new())),
        }
    }

    #[cfg(test)]
    fn active_connections(&self) -> HashSet<usize> {
        self.connections.read().expect("lock").clone()
    }

    #[cfg(test)]
    fn subscriptions(&self, connection_id: usize) -> HashSet<ShardId> {
        let mut shards = HashSet::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            if subscribed_connections.contains_key(&connection_id) {
                shards.insert(*shard);
            }
        }

        shards
    }

    #[cfg(test)]
    fn shard_subscription_counts(&self) -> mz_ore::collections::HashMap<ShardId, usize> {
        let mut shards = mz_ore::collections::HashMap::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            shards.insert(*shard, subscribed_connections.len());
        }

        shards
    }
}

/// A gRPC-based implementation of a Persist PubSub server.
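///
/// A minimal serving sketch (illustrative, not compiled; assumes a Tokio
/// runtime and that `cfg` and `metrics_registry` are built elsewhere; the
/// port is an arbitrary example):
///
/// ```ignore
/// let server = PersistGrpcPubSubServer::new(&cfg, &metrics_registry);
/// let listen_addr: SocketAddr = "0.0.0.0:6879".parse()?;
/// // Runs until the serving task is cancelled.
/// server.serve(listen_addr).await?;
/// ```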
#[derive(Debug)]
pub struct PersistGrpcPubSubServer {
    cfg: PersistConfig,
    state: Arc<PubSubState>,
}

impl PersistGrpcPubSubServer {
    /// Creates a new [PersistGrpcPubSubServer].
    pub fn new(cfg: &PersistConfig, metrics_registry: &MetricsRegistry) -> Self {
        let metrics = PubSubServerMetrics::new(metrics_registry);
        let state = Arc::new(PubSubState {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(metrics),
        });

        PersistGrpcPubSubServer {
            cfg: cfg.clone(),
            state,
        }
    }

    /// Creates a connection to [PersistGrpcPubSubServer] that is directly connected
    /// to the server state. Calls into this connection do not go over the network
    /// nor require message serde.
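    ///
    /// For example (illustrative, not compiled):
    ///
    /// ```ignore
    /// let server = PersistGrpcPubSubServer::new(&cfg, &metrics_registry);
    /// // In-process clients bypass gRPC entirely.
    /// let PubSubClientConnection { sender, receiver } =
    ///     server.new_same_process_connection();
    /// ```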
    pub fn new_same_process_connection(&self) -> PubSubClientConnection {
        let (tx, rx) =
            tokio::sync::mpsc::channel(PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&self.cfg));
        let sender: Arc<dyn PubSubSender> = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            Arc::clone(&self.state).new_connection(tx),
        )));

        PubSubClientConnection {
            sender,
            receiver: Box::new(
                ReceiverStream::new(rx).map(|x| x.expect("cannot receive grpc errors locally")),
            ),
        }
    }

    /// Starts the gRPC server. Consumes `self` and runs until the task is cancelled.
    pub async fn serve(self, listen_addr: SocketAddr) -> Result<(), anyhow::Error> {
        // Increase the default message decoding limit to avoid unnecessary panics
        tonic::transport::Server::builder()
            .add_service(ProtoPersistPubSubServer::new(self).max_decoding_message_size(usize::MAX))
            .serve(listen_addr)
            .await?;
        Ok(())
    }

    /// Starts the gRPC server with the given listener stream.
    /// Consumes `self` and runs until the task is cancelled.
    pub async fn serve_with_stream(
        self,
        listener: tokio_stream::wrappers::TcpListenerStream,
    ) -> Result<(), anyhow::Error> {
        tonic::transport::Server::builder()
            .add_service(ProtoPersistPubSubServer::new(self))
            .serve_with_incoming(listener)
            .await?;
        Ok(())
    }
}

#[async_trait]
impl proto_persist_pub_sub_server::ProtoPersistPubSub for PersistGrpcPubSubServer {
    type PubSubStream = Pin<Box<dyn Stream<Item = Result<ProtoPubSubMessage, Status>> + Send>>;

    #[mz_ore::instrument(name = "persist::rpc::server", level = "info")]
    async fn pub_sub(
        &self,
        request: Request<Streaming<ProtoPubSubMessage>>,
    ) -> Result<Response<Self::PubSubStream>, Status> {
        let caller_id = request
            .metadata()
            .get(AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY))
            .and_then(|key| key.to_str().ok())
            .map(|key| key.to_string())
            .unwrap_or_else(|| "unknown".to_string());
        info!("Received Persist PubSub connection from: {:?}", caller_id);

        let mut in_stream = request.into_inner();
        let (tx, rx) =
            tokio::sync::mpsc::channel(PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE.get(&self.cfg));

        let caller = caller_id.clone();
        let cfg = Arc::clone(&self.cfg.configs);
        let server_state = Arc::clone(&self.state);
        // We spawn a task here to clean up after a connection error or disconnect; otherwise the
        // stream would not be polled after the connection drops. In our case, we want to clear the
        // connection and its subscriptions from our shared state when it drops.
        let connection_span = info_span!("connection", caller_id);
        mz_ore::task::spawn(
            || format!("persist_pubsub_connection({})", caller),
            async move {
                let connection = server_state.new_connection(tx);
                while let Some(result) = in_stream.next().await {
                    let req = match result {
                        Ok(req) => req,
                        Err(err) => {
                            warn!("pubsub connection err: {}", err);
                            break;
                        }
                    };

                    match req.message {
                        None => {
                            warn!("received empty message from: {}", caller_id);
                        }
                        Some(proto_pub_sub_message::Message::PushDiff(req)) => {
                            let shard_id = req.shard_id.parse().expect("valid shard id");
                            let diff = VersionedData {
                                seqno: req.seqno.into_rust().expect("valid seqno"),
                                data: req.diff.clone(),
                            };
                            if PUBSUB_PUSH_DIFF_ENABLED.get(&cfg) {
                                connection.push_diff(&shard_id, &diff);
                            }
                        }
                        Some(proto_pub_sub_message::Message::Subscribe(diff)) => {
                            let shard_id = diff.shard_id.parse().expect("valid shard id");
                            connection.subscribe(&shard_id);
                        }
                        Some(proto_pub_sub_message::Message::Unsubscribe(diff)) => {
                            let shard_id = diff.shard_id.parse().expect("valid shard id");
                            connection.unsubscribe(&shard_id);
                        }
                    }
                }

                info!("Persist PubSub connection ended: {:?}", caller_id);
            }
            .instrument(connection_span),
        );

        let out_stream: Self::PubSubStream = Box::pin(ReceiverStream::new(rx));
        Ok(Response::new(out_stream))
    }
}

/// An active connection managed by [PubSubState].
///
/// When dropped, removes itself from [PubSubState], clearing all of its subscriptions.
#[derive(Debug)]
pub(crate) struct PubSubConnection {
    connection_id: usize,
    notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    state: Arc<PubSubState>,
}

impl PubSubSenderInternal for PubSubConnection {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.state.push_diff(self.connection_id, shard_id, diff)
    }

    fn subscribe(&self, shard_id: &ShardId) {
        self.state
            .subscribe(self.connection_id, self.notifier.clone(), shard_id)
    }

    fn unsubscribe(&self, shard_id: &ShardId) {
        self.state.unsubscribe(self.connection_id, shard_id)
    }
}

impl Drop for PubSubConnection {
    fn drop(&mut self) {
        self.state.remove_connection(self.connection_id)
    }
}

#[cfg(test)]
mod pubsub_state {
    use std::str::FromStr;
    use std::sync::Arc;
    use std::sync::LazyLock;

    use bytes::Bytes;
    use mz_ore::collections::HashSet;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use tokio::sync::mpsc::Receiver;
    use tokio::sync::mpsc::error::TryRecvError;
    use tonic::Status;

    use crate::ShardId;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::rpc::{PubSubSenderInternal, PubSubState};

    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());

    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };

    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_push_diff() {
        let state = Arc::new(PubSubState::new_for_test());
        state.push_diff(100, &SHARD_ID_0, &VERSIONED_DATA_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_subscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        let (tx, _) = tokio::sync::mpsc::channel(100);
        state.subscribe(100, tx, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_unsubscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        state.unsubscribe(100, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_remove() {
        let state = Arc::new(PubSubState::new_for_test());
        state.remove_connection(100)
    }

    #[mz_ore::test]
    fn test_single_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx, mut rx) = tokio::sync::mpsc::channel(100);
        let connection = Arc::clone(&state).new_connection(tx);

        assert_eq!(
            state.active_connections(),
            HashSet::from([connection.connection_id])
        );

        // no messages should have been broadcasted yet
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        connection.push_diff(
            &SHARD_ID_0,
            &VersionedData {
                seqno: SeqNo::minimum(),
                data: Bytes::new(),
            },
        );

        // server should not broadcast a message back to originating client
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        // a connection can subscribe to a shard
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([SHARD_ID_0.clone()])
        );

        // a connection can unsubscribe
        connection.unsubscribe(&SHARD_ID_0);
        assert!(state.subscriptions(connection.connection_id).is_empty());

        // a connection can subscribe to many shards
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_1);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // and to a single shard many times idempotently
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // dropping the connection should unsubscribe all shards and unregister the connection
        let connection_id = connection.connection_id;
        drop(connection);
        assert!(state.subscriptions(connection_id).is_empty());
        assert!(state.active_connections().is_empty());
    }

    #[mz_ore::test]
    fn test_many_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx1, mut rx1) = tokio::sync::mpsc::channel(100);
        let conn1 = Arc::clone(&state).new_connection(tx1);

        let (tx2, mut rx2) = tokio::sync::mpsc::channel(100);
        let conn2 = Arc::clone(&state).new_connection(tx2);

        let (tx3, mut rx3) = tokio::sync::mpsc::channel(100);
        let conn3 = Arc::clone(&state).new_connection(tx3);

        conn1.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_1);

        assert_eq!(
            state.active_connections(),
            HashSet::from([
                conn1.connection_id,
                conn2.connection_id,
                conn3.connection_id
            ])
        );

        // broadcast a diff to a shard subscribed to by several connections
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx1, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff to a shard the publisher itself subscribes to; it should
        // not receive the diff back.
1348        conn1.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1349        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1350        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1351        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1352
1353        // broadcast a diff to a shard subscribed to by one connection
1354        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
1355        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1356        assert_push(&mut rx2, &SHARD_ID_1, &VERSIONED_DATA_1);
1357        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1358
1359        // broadcast a diff to a shard subscribed to by no connections
1360        conn2.unsubscribe(&SHARD_ID_1);
1361        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
1362        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1363        assert!(matches!(rx2.try_recv(), Err(TryRecvError::Empty)));
1364        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1365
1366        // dropping connections unsubscribes them
1367        let conn1_id = conn1.connection_id;
1368        drop(conn1);
1369        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1370        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Disconnected)));
1371        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1372        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1373
1374        assert!(state.subscriptions(conn1_id).is_empty());
1375        assert_eq!(
1376            state.subscriptions(conn2.connection_id),
1377            HashSet::from([*SHARD_ID_0])
1378        );
1379        assert_eq!(state.subscriptions(conn3.connection_id), HashSet::new());
1380        assert_eq!(
1381            state.active_connections(),
1382            HashSet::from([conn2.connection_id, conn3.connection_id])
1383        );
1384    }
1385
    fn assert_push(
        rx: &mut Receiver<Result<ProtoPubSubMessage, Status>>,
        shard: &ShardId,
        data: &VersionedData,
    ) {
        let message = rx
            .try_recv()
            .expect("message in channel")
            .expect("pubsub")
            .message
            .expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }
}

#[cfg(test)]
mod grpc {
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::str::FromStr;
    use std::sync::{Arc, LazyLock};
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use futures_util::FutureExt;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::assert_none;
    use mz_ore::collections::HashMap;
    use mz_ore::metrics::MetricsRegistry;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use tokio::net::TcpListener;
    use tokio_stream::StreamExt;
    use tokio_stream::wrappers::TcpListenerStream;

    use crate::ShardId;
    use crate::cfg::PersistConfig;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::metrics::Metrics;
    use crate::rpc::{
        GrpcPubSubClient, PUBSUB_CLIENT_ENABLED, PUBSUB_RECONNECT_BACKOFF, PersistGrpcPubSubServer,
        PersistPubSubClient, PersistPubSubClientConfig, PubSubState,
    };

    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };
    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
    const SUBSCRIPTIONS_TIMEOUT: Duration = Duration::from_secs(3);
    const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);

    // NB: we use separate runtimes for client and server throughout these tests to cleanly drop
    // ALL tasks (including spawned child tasks) associated with one end of a connection, to most
    // closely model an actual disconnect.

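    /// Tests the basic lifecycle of a PubSub client against a running server:
    /// connect, subscribe, and clean up server state on disconnect.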
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    fn grpc_server() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");

        // start the server
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // start a client
        {
            let _guard = client_runtime.enter();
            mz_ore::task::spawn(|| "client".to_string(), async move {
                let client = GrpcPubSubClient::connect(
                    PersistPubSubClientConfig {
                        url: format!("http://{}", addr),
                        caller_id: "client".to_string(),
                        persist_cfg: test_persist_config(),
                    },
                    metrics,
                );
                let _token = client.sender.subscribe(&SHARD_ID_0);
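                // park this task forever so the client stays connected for the
                // duration of the test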
                tokio::time::sleep(Duration::MAX).await;
            });
        }

        // wait until the client is connected and subscribed
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await
        });

        // drop the client by tearing down its runtime
        client_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // server should notice the client dropping and clean up its state
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().is_empty()
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::new()
            })
            .await
        });
    }

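    /// Tests that a client sender created before the server exists connects
    /// once the server comes up, and that it reconnects and rehydrates its
    /// subscriptions after the server is restarted.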
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    fn grpc_client_sender_reconnects() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // start a client
        let client = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // we can subscribe before connecting to the pubsub server
        let _token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        // we can also subscribe and unsubscribe before connecting
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        drop(_token_2);

        // create the server after the client is up
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            // the client connects automatically once the server is up
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            // the client rehydrates its subscriptions; notably, only the shard
            // that still has an active token is included
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await;
        });

        // kill the server
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // the client can still issue requests without error
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);

        // create a new server bound to the same address
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            // the client automatically reconnects to the new server
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            // and rehydrates its subscriptions, including the one created
            // while the server was unavailable
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts()
                    == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
            })
            .await;
        });
    }

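    /// Tests the reference-counting behavior of subscription tokens: a shard
    /// remains subscribed while any token is alive and is unsubscribed once
    /// all tokens are dropped.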
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    async fn grpc_client_sender_subscription_tokens() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));

        let (addr, tcp_listener_stream) = new_tcp_listener().await;
        let server_state = spawn_server(tcp_listener_stream).await;

        let client = GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client".to_string(),
                persist_cfg: test_persist_config(),
            },
            metrics,
        );

        // our client connects
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;

        // we can subscribe to a shard, receiving back a token
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // dropping the token will unsubscribe our client
        drop(token);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // we can resubscribe to a shard
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // we can subscribe many times idempotently, receiving back Arcs to the same token
        let token2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let token3 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        assert_eq!(Arc::strong_count(&token), 3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // dropping all of the tokens will unsubscribe the shard
        drop(token);
        drop(token2);
        drop(token3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // we can subscribe to many shards
        let _token0 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let _token1 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts()
                == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
        })
        .await;
    }

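    /// Tests that diffs pushed by one client are broadcast to other subscribed
    /// clients (but not echoed back to the pusher), across server restarts.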
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    fn grpc_client_receiver() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // create two clients, so we can test that broadcast messages are received by the other
        let mut client_1 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_1".to_string(),
                    persist_cfg: test_persist_config(),
                },
                Arc::clone(&metrics),
            )
        });
        let mut client_2 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_2".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // we can check our receiver output before connecting to the server.
        // these checks are inherently racy: a sent message may take arbitrarily
        // long to arrive, so an empty receiver does not prove nothing was sent,
        // but it is a useful smoke test.
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // start the server
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // wait until both clients are connected
        server_runtime.block_on(poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 2
        }));

        // no messages have been broadcast yet
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // subscribe and send a diff
        let _token_client_1 = Arc::clone(&client_1.sender).subscribe(&SHARD_ID_0);
        let _token_client_2 = Arc::clone(&client_2.sender).subscribe(&SHARD_ID_0);
        server_runtime.block_on(poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
        }));

        // the subscribed client that did not push the diff receives it;
        // the pushing client does not receive its own diff back
        client_1.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_1);
        assert_none!(client_1.receiver.next().now_or_never());
        client_runtime.block_on(async {
            assert_push(
                client_2.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_1,
            )
        });

        // kill the server
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // receivers can still be polled without error
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // create a new server bound to the same address
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // the clients automatically reconnect to the new server and rehydrate
        // their subscriptions
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 2
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
            })
            .await;
        });

        // pushing and receiving diffs works as expected.
        // this time we'll push from the other client.
        client_2.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        client_runtime.block_on(async {
            assert_push(
                client_1.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_0,
            )
        });
        assert_none!(client_2.receiver.next().now_or_never());
    }

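    /// Binds a TCP listener to an OS-assigned port on localhost, returning the
    /// bound address and a stream of incoming connections.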
    async fn new_tcp_listener() -> (SocketAddr, TcpListenerStream) {
        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let tcp_listener = TcpListener::bind(addr).await.expect("tcp listener");

        (
            tcp_listener.local_addr().expect("bound to local address"),
            TcpListenerStream::new(tcp_listener),
        )
    }

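    /// Spawns a PubSub server task serving the given TCP listener stream,
    /// returning a handle to its shared state for assertions.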
    #[allow(clippy::unused_async)]
    async fn spawn_server(tcp_listener_stream: TcpListenerStream) -> Arc<PubSubState> {
        let server = PersistGrpcPubSubServer::new(&test_persist_config(), &MetricsRegistry::new());
        let server_state = Arc::clone(&server.state);

        let _server_task = mz_ore::task::spawn(|| "server".to_string(), async move {
            server.serve_with_stream(tcp_listener_stream).await
        });
        server_state
    }

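    /// Polls `f` until it returns true, panicking if `timeout` elapses first.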
    async fn poll_until_true<F>(timeout: Duration, f: F)
    where
        F: Fn() -> bool,
    {
        let start = Instant::now();
        loop {
            if f() {
                return;
            }

            if start.elapsed() > timeout {
                panic!("timed out");
            }

            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }

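    /// Asserts that the given message is a `PushDiff` matching the given shard
    /// and versioned data.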
    fn assert_push(message: ProtoPubSubMessage, shard: &ShardId, data: &VersionedData) {
        let message = message.message.expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }

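    /// Returns a `PersistConfig` with PubSub enabled and the reconnect backoff
    /// zeroed out so clients retry immediately in tests.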
    fn test_persist_config() -> PersistConfig {
        let cfg = PersistConfig::new_for_tests();

        let mut updates = ConfigUpdates::default();
        updates.add(&PUBSUB_CLIENT_ENABLED, true);
        updates.add(&PUBSUB_RECONNECT_BACKOFF, Duration::ZERO);
        cfg.apply_from(&updates);

        cfg
    }
}