Skip to main content

mz_persist_client/
rpc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! gRPC-based implementations of Persist PubSub client and server.
11
12use std::collections::BTreeMap;
13use std::collections::btree_map::Entry;
14use std::fmt::{Debug, Formatter};
15use std::net::SocketAddr;
16use std::pin::Pin;
17use std::str::FromStr;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::{Arc, Mutex, RwLock, Weak};
20use std::time::{Duration, Instant, SystemTime};
21
22use anyhow::{Error, anyhow};
23use async_trait::async_trait;
24use bytes::Bytes;
25use futures::Stream;
26use futures_util::StreamExt;
27use mz_dyncfg::Config;
28use mz_ore::cast::CastFrom;
29use mz_ore::collections::{HashMap, HashSet};
30use mz_ore::metrics::MetricsRegistry;
31use mz_ore::retry::RetryResult;
32use mz_ore::task::JoinHandle;
33use mz_persist::location::VersionedData;
34use mz_proto::{ProtoType, RustType};
35use prost::Message;
36use tokio::sync::mpsc::Sender;
37use tokio::sync::mpsc::error::TrySendError;
38use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
39use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
40use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap};
41use tonic::transport::Endpoint;
42use tonic::{Extensions, Request, Response, Status, Streaming};
43use tracing::{Instrument, debug, error, info, info_span, warn};
44
45use crate::ShardId;
46use crate::cache::{DynState, StateCache};
47use crate::cfg::PersistConfig;
48use crate::internal::metrics::{PubSubClientCallMetrics, PubSubServerMetrics};
49use crate::internal::service::proto_persist_pub_sub_client::ProtoPersistPubSubClient;
50use crate::internal::service::proto_persist_pub_sub_server::ProtoPersistPubSubServer;
51use crate::internal::service::{
52    ProtoPubSubMessage, ProtoPushDiff, ProtoSubscribe, ProtoUnsubscribe,
53    proto_persist_pub_sub_server, proto_pub_sub_message,
54};
55use crate::metrics::Metrics;
56
/// Determines whether PubSub clients should connect to the PubSub server.
///
/// Checked on every iteration of the client's reconnect loop; while disabled,
/// the client sleeps and re-checks rather than connecting.
pub(crate) const PUBSUB_CLIENT_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_client_enabled",
    true,
    "Whether to connect to the Persist PubSub service.",
);

/// For connected clients, determines whether to push state diffs to the PubSub
/// server. For the server, determines whether to broadcast state diffs to
/// subscribed clients.
pub(crate) const PUBSUB_PUSH_DIFF_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_push_diff_enabled",
    true,
    "Whether to push state diffs to Persist PubSub.",
);
72
/// For clients that share a process with the PubSub server (see
/// [MetricsSameProcessPubSubSender]), determines whether `subscribe` calls are
/// delegated to the wrapped sender or satisfied with a no-op token.
pub(crate) const PUBSUB_SAME_PROCESS_DELEGATE_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_same_process_delegate_enabled",
    true,
    "Whether to push state diffs to Persist PubSub on the same process.",
);
81
/// Timeout per connection attempt to Persist PubSub service.
pub(crate) const PUBSUB_CONNECT_ATTEMPT_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_connect_attempt_timeout",
    Duration::from_secs(5),
    "Timeout per connection attempt to Persist PubSub service.",
);

/// Timeout per request attempt to Persist PubSub service.
pub(crate) const PUBSUB_REQUEST_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_request_timeout",
    Duration::from_secs(5),
    "Timeout per request attempt to Persist PubSub service.",
);

/// Maximum backoff when retrying connection establishment to Persist PubSub service.
pub(crate) const PUBSUB_CONNECT_MAX_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_connect_max_backoff",
    Duration::from_secs(60),
    "Maximum backoff when retrying connection establishment to Persist PubSub service.",
);

/// Size of channel used to buffer send messages to PubSub service.
///
/// Read once when the client connection is created (the broadcast channel
/// outlives individual gRPC connections).
pub(crate) const PUBSUB_CLIENT_SENDER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_sender_channel_size",
    25,
    "Size of channel used to buffer send messages to PubSub service.",
);

/// Size of channel used to buffer received messages from PubSub service.
///
/// Read once when the client connection is created.
pub(crate) const PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_receiver_channel_size",
    25,
    "Size of channel used to buffer received messages from PubSub service.",
);

/// Size of channel used per connection to buffer broadcasted messages from PubSub server.
pub(crate) const PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_server_connection_channel_size",
    25,
    "Size of channel used per connection to buffer broadcasted messages from PubSub server.",
);

/// Size of channel used by the state cache to broadcast shard state references.
pub(crate) const PUBSUB_STATE_CACHE_SHARD_REF_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_state_cache_shard_ref_channel_size",
    25,
    "Size of channel used by the state cache to broadcast shard state references.",
);

/// Backoff after an established connection to Persist PubSub service fails.
///
/// Applied in the reconnect loop before every attempt except the first.
pub(crate) const PUBSUB_RECONNECT_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_reconnect_backoff",
    Duration::from_secs(5),
    "Backoff after an established connection to Persist PubSub service fails.",
);

/// Max message size, used to configure gRPC servers and clients.
///
/// While `max_encoding_message_size` defaults to `usize::MAX`, `max_decoding_message_size` only
/// defaults to 4MB, so we bump it to avoid protocol errors.
const MAX_GRPC_MESSAGE_SIZE: usize = usize::MAX;
143
/// Top-level Trait to create a PubSubClient.
///
/// Returns a [PubSubClientConnection] with a [PubSubSender] for issuing RPCs to the PubSub
/// server, and a [PubSubReceiver] that receives messages, such as state diffs.
pub trait PersistPubSubClient {
    /// Receive handles with which to push and subscribe to diffs.
    ///
    /// Note: this is synchronous; implementations (e.g. [GrpcPubSubClient])
    /// may establish the underlying connection lazily in a background task.
    fn connect(
        pubsub_config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) -> PubSubClientConnection;
}
155
/// Wrapper type for a matching [PubSubSender] and [PubSubReceiver] client pair.
#[derive(Debug)]
pub struct PubSubClientConnection {
    /// The sender client to Persist PubSub. Shared (`Arc`) so subscription
    /// tokens and multiple callers can hold it.
    pub sender: Arc<dyn PubSubSender>,
    /// The receiver client to Persist PubSub. Uniquely owned: a single
    /// consumer drains the stream of incoming messages.
    pub receiver: Box<dyn PubSubReceiver>,
}
164
165impl PubSubClientConnection {
166    /// Creates a new [PubSubClientConnection] from a matching [PubSubSender] and [PubSubReceiver].
167    pub fn new(sender: Arc<dyn PubSubSender>, receiver: Box<dyn PubSubReceiver>) -> Self {
168        Self { sender, receiver }
169    }
170
171    /// Creates a no-op [PubSubClientConnection] that neither sends nor receives messages.
172    pub fn noop() -> Self {
173        Self {
174            sender: Arc::new(NoopPubSubSender),
175            receiver: Box::new(futures::stream::empty()),
176        }
177    }
178}
179
/// The public send-side client to Persist PubSub.
pub trait PubSubSender: std::fmt::Debug + Send + Sync {
    /// Push a diff to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribe the corresponding [PubSubReceiver] to diffs for the given shard.
    ///
    /// Returns a token that, when dropped, will unsubscribe the client from the
    /// shard.
    ///
    /// If the client is already subscribed to the shard, repeated calls will make
    /// no further calls to the server and instead return clones of the `Arc<ShardSubscriptionToken>`.
    ///
    /// Takes `self: Arc<Self>` so the returned token can hold a reference back
    /// to the sender for its drop-time unsubscribe.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken>;
}
194
/// The internal send-side client trait to Persist PubSub, responsible for issuing RPCs
/// to the PubSub service. This trait is separated out from [PubSubSender] to keep the
/// client implementations straightforward, while offering a more ergonomic public API
/// in [PubSubSender].
trait PubSubSenderInternal: Debug + Send + Sync {
    /// Push a diff to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribe the corresponding [PubSubReceiver] to diffs for the given shard.
    ///
    /// This call is idempotent and is a no-op for an already subscribed shard.
    fn subscribe(&self, shard_id: &ShardId);

    /// Unsubscribe the corresponding [PubSubReceiver] from diffs for the given shard.
    ///
    /// This call is idempotent and is a no-op for already unsubscribed shards.
    fn unsubscribe(&self, shard_id: &ShardId);
}
213
/// The receive-side client to Persist PubSub.
///
/// Returns diffs (and maybe in the future, blobs) for any shards subscribed to
/// by the corresponding `PubSubSender`.
pub trait PubSubReceiver:
    Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}

// Blanket impl: any compatible stream of [ProtoPubSubMessage]s automatically
// implements [PubSubReceiver].
impl<T> PubSubReceiver for T where
    T: Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}
227
/// A token corresponding to a subscription to diffs for a particular shard.
///
/// When dropped, the client that originated the token will be unsubscribed
/// from further diffs to the shard.
pub struct ShardSubscriptionToken {
    pub(crate) shard_id: ShardId,
    // Held so `Drop` can issue the unsubscribe RPC back through the sender.
    sender: Arc<dyn PubSubSenderInternal>,
}
236
237impl Debug for ShardSubscriptionToken {
238    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
239        let ShardSubscriptionToken {
240            shard_id,
241            sender: _sender,
242        } = self;
243        write!(f, "ShardSubscriptionToken({})", shard_id)
244    }
245}
246
impl Drop for ShardSubscriptionToken {
    // Dropping the token unsubscribes the originating client from the shard.
    fn drop(&mut self) {
        self.sender.unsubscribe(&self.shard_id);
    }
}
252
/// A gRPC metadata key to indicate the caller id of a client.
pub const PERSIST_PUBSUB_CALLER_KEY: &str = "persist-pubsub-caller-id";

/// Client configuration for connecting to a remote PubSub server.
#[derive(Debug)]
pub struct PersistPubSubClientConfig {
    /// Connection address for the pubsub server, e.g. `http://localhost:6879`
    pub url: String,
    /// A caller ID for the client, sent as gRPC metadata under
    /// [PERSIST_PUBSUB_CALLER_KEY]. Used for debugging.
    pub caller_id: String,
    /// A handle to [PersistConfig], consulted for the `persist_pubsub_*`
    /// dynamic configuration values above.
    pub persist_cfg: PersistConfig,
}
266
/// A [PersistPubSubClient] implementation backed by gRPC.
///
/// Returns a [PubSubClientConnection] backed by channels that submit and receive
/// messages to and from a long-lived bidirectional gRPC stream. The gRPC stream
/// will be transparently reestablished if the connection is lost.
#[derive(Debug)]
pub struct GrpcPubSubClient;
274
impl GrpcPubSubClient {
    /// Connection loop for the gRPC PubSub client: repeatedly (re)establishes
    /// a connection to the server at `config.url` and pumps messages in both
    /// directions until a fatal error or shutdown.
    ///
    /// Outbound messages are read from `send_requests` -- a broadcast channel,
    /// so a fresh `Receiver` can be created for each new connection. Inbound
    /// messages are forwarded into `receiver_input`. `sender` tracks the
    /// client's live shard subscriptions so they can be replayed onto each new
    /// connection. Returns only on a fatal connect error (e.g. an unparseable
    /// URL) or once the receive side has been dropped.
    async fn reconnect_to_server_forever(
        send_requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
        receiver_input: &tokio::sync::mpsc::Sender<ProtoPubSubMessage>,
        sender: Arc<SubscriptionTrackingSender>,
        metadata: MetadataMap,
        config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) {
        // Once enabled, the PubSub server cannot be disabled or otherwise
        // reconfigured. So we wait for at least one configuration sync to
        // complete. This gives `environmentd` at least one chance to update
        // PubSub configuration parameters. See database-issues#7168 for details.
        config.persist_cfg.configs_synced_once().await;

        let mut is_first_connection_attempt = true;
        loop {
            let sender = Arc::clone(&sender);
            metrics.pubsub_client.grpc_connection.connected.set(0);

            // While the client is disabled, poll the flag rather than connect.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                tokio::time::sleep(Duration::from_secs(5)).await;
                continue;
            }

            // add a bit of backoff when reconnecting after some network/server failure
            if is_first_connection_attempt {
                is_first_connection_attempt = false;
            } else {
                tokio::time::sleep(PUBSUB_RECONNECT_BACKOFF.get(&config.persist_cfg)).await;
            }

            info!("Connecting to Persist PubSub: {}", config.url);
            let client = mz_ore::retry::Retry::default()
                .clamp_backoff(PUBSUB_CONNECT_MAX_BACKOFF.get(&config.persist_cfg))
                .retry_async(|_| async {
                    metrics
                        .pubsub_client
                        .grpc_connection
                        .connect_call_attempt_count
                        .inc();
                    // A URL that fails to parse will never succeed: treat it
                    // as fatal rather than retrying.
                    let endpoint = match Endpoint::from_str(&config.url) {
                        Ok(endpoint) => endpoint,
                        Err(err) => return RetryResult::FatalErr(err),
                    };
                    ProtoPersistPubSubClient::connect(
                        endpoint
                            .connect_timeout(
                                PUBSUB_CONNECT_ATTEMPT_TIMEOUT.get(&config.persist_cfg),
                            )
                            .timeout(PUBSUB_REQUEST_TIMEOUT.get(&config.persist_cfg)),
                    )
                    .await
                    .into()
                })
                .await;

            let mut client = match client {
                // Bump the decode limit; see MAX_GRPC_MESSAGE_SIZE.
                Ok(client) => client.max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
                Err(err) => {
                    error!("fatal error connecting to persist pubsub: {:?}", err);
                    return;
                }
            };

            metrics
                .pubsub_client
                .grpc_connection
                .connection_established_count
                .inc();
            metrics.pubsub_client.grpc_connection.connected.set(1);

            info!("Connected to Persist PubSub: {}", config.url);

            // A fresh broadcast Receiver for this connection; only one is
            // expected to be live at a time.
            let mut broadcast = BroadcastStream::new(send_requests.subscribe());
            let broadcast_errors = metrics
                .pubsub_client
                .grpc_connection
                .broadcast_recv_lagged_count
                .clone();

            // `client.pub_sub(...)` starts a hyper background task reading from the
            // `broadcast_messages` stream, to serve the HTTP2 connection. The broadcast stream
            // doesn't normally terminate, which means the HTTP2 connection doesn't terminate
            // either. We set up a cancelation token to force termination and avoid a connection
            // leak.
            let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();

            // shard subscriptions are tracked by connection on the server, so if our
            // gRPC stream is ever swapped out, we must inform the server which shards
            // our client intended to be subscribed to.
            let broadcast_messages = async_stream::stream! {
                let mut cancel_rx = std::pin::pin!(cancel_rx);
                'reconnect: loop {
                    // If we have active subscriptions, resend them.
                    for id in sender.subscriptions() {
                        debug!("re-subscribing to shard: {id}");
                        let msg = proto_pub_sub_message::Message::Subscribe(
                            ProtoSubscribe {
                                shard_id: id.into_proto(),
                            },
                        );
                        yield create_request(msg);
                    }

                    // Forward on messages from the broadcast channel, reconnecting if necessary.
                    loop {
                        tokio::select! {
                            message = broadcast.next() => {
                                debug!("sending pubsub message: {:?}", message);
                                match message {
                                    Some(Ok(message)) => yield message,
                                    // We fell behind the broadcast channel:
                                    // count the dropped messages and replay
                                    // our subscriptions from the top.
                                    Some(Err(BroadcastStreamRecvError::Lagged(i))) => {
                                        broadcast_errors.inc_by(i);
                                        continue 'reconnect;
                                    }
                                    None => {
                                        debug!("exhausted pubsub broadcast stream; shutting down");
                                        return;
                                    }
                                }
                            }
                            _ = &mut cancel_rx => {
                                debug!("pubsub broadcast stream cancelled; shutting down");
                                return;
                            }
                        }
                    }
                }
            };
            let pubsub_request =
                Request::from_parts(metadata.clone(), Extensions::default(), broadcast_messages);

            let responses = match client.pub_sub(pubsub_request).await {
                Ok(response) => response.into_inner(),
                Err(err) => {
                    warn!("pub_sub rpc error: {:?}", err);
                    continue;
                }
            };

            let stream_completed = GrpcPubSubClient::consume_grpc_stream(
                responses,
                receiver_input,
                &config,
                metrics.as_ref(),
            )
            .await;

            // Dropping the cancel sender fires `cancel_rx`, terminating the
            // outbound stream (and with it the HTTP2 connection).
            drop(cancel_tx);

            match stream_completed {
                // common case: reconnect due to some transient error
                Ok(_) => continue,
                // uncommon case: we should stop connecting to the PubSub server entirely.
                // in practice, we should only see this during shut down.
                Err(err) => {
                    warn!("shutting down connection loop to Persist PubSub: {}", err);
                    return;
                }
            }
        }
    }

    /// Forwards messages from an established gRPC response stream into
    /// `receiver_input` until the stream ends, errors, or the client is
    /// disabled.
    ///
    /// Returns `Ok(())` when the caller should reconnect (stream exhausted,
    /// transient gRPC error, or client disabled), and `Err` when the receive
    /// side has been dropped and the connection loop should shut down.
    async fn consume_grpc_stream(
        mut responses: Streaming<ProtoPubSubMessage>,
        receiver_input: &Sender<ProtoPubSubMessage>,
        config: &PersistPubSubClientConfig,
        metrics: &Metrics,
    ) -> Result<(), Error> {
        loop {
            // Re-checked each iteration so a dynamic config change takes
            // effect promptly.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                return Ok(());
            }

            debug!("awaiting next pubsub response");
            match responses.next().await {
                Some(Ok(message)) => {
                    debug!("received pubsub message: {:?}", message);
                    match receiver_input.send(message).await {
                        Ok(_) => {}
                        // if the receiver has dropped, we can drop our
                        // no-longer-needed grpc connection entirely.
                        Err(err) => {
                            return Err(anyhow!("closing pubsub grpc client connection: {}", err));
                        }
                    }
                }
                Some(Err(err)) => {
                    metrics.pubsub_client.grpc_connection.grpc_error_count.inc();
                    warn!("pubsub client error: {:?}", err);
                    return Ok(());
                }
                None => return Ok(()),
            }
        }
    }
}
473
impl PersistPubSubClient for GrpcPubSubClient {
    /// Builds the channel pair and spawns the background task that owns the
    /// gRPC connection (see [GrpcPubSubClient::reconnect_to_server_forever]).
    /// Returns immediately; the connection is established asynchronously.
    fn connect(config: PersistPubSubClientConfig, metrics: Arc<Metrics>) -> PubSubClientConnection {
        // Create a stable channel for our client to transmit message into our gRPC stream. We use a
        // broadcast to allow us to create new Receivers on demand, in case the underlying gRPC stream
        // is swapped out (e.g. due to connection failure). It is expected that only 1 Receiver is
        // ever active at a given time.
        let (send_requests, _) = tokio::sync::broadcast::channel(
            PUBSUB_CLIENT_SENDER_CHANNEL_SIZE.get(&config.persist_cfg),
        );
        // Create a stable channel to receive messages from our gRPC stream. The input end lives inside
        // a task that continuously reads from the active gRPC stream, decoupling the `PubSubReceiver`
        // from the lifetime of a specific gRPC connection.
        let (receiver_input, receiver_output) = tokio::sync::mpsc::channel(
            PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&config.persist_cfg),
        );

        // Wrap the gRPC sender in subscription tracking so active shard
        // subscriptions can be replayed after a reconnect.
        let sender = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            GrpcPubSubSender {
                metrics: Arc::clone(&metrics),
                requests: send_requests.clone(),
            },
        )));
        let pubsub_sender = Arc::clone(&sender);
        mz_ore::task::spawn(
            || "persist::rpc::client::connection".to_string(),
            async move {
                // Attach the caller id as gRPC metadata for server-side
                // debugging; fall back to "unknown" if it isn't valid ASCII
                // metadata.
                let mut metadata = MetadataMap::new();
                metadata.insert(
                    AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY),
                    AsciiMetadataValue::try_from(&config.caller_id)
                        .unwrap_or_else(|_| AsciiMetadataValue::from_static("unknown")),
                );

                GrpcPubSubClient::reconnect_to_server_forever(
                    send_requests,
                    &receiver_input,
                    pubsub_sender,
                    metadata,
                    config,
                    metrics,
                )
                .await;
            },
        );

        PubSubClientConnection {
            sender,
            receiver: Box::new(ReceiverStream::new(receiver_output)),
        }
    }
}
525
/// An internal, gRPC-backed implementation of [PubSubSender].
struct GrpcPubSubSender {
    // Metrics for successes/failures/bytes of each RPC kind.
    metrics: Arc<Metrics>,
    // Broadcast channel feeding the outbound half of the active gRPC stream.
    requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
}
531
532impl Debug for GrpcPubSubSender {
533    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
534        let GrpcPubSubSender {
535            metrics: _metrics,
536            requests: _requests,
537        } = self;
538
539        write!(f, "GrpcPubSubSender")
540    }
541}
542
543fn create_request(message: proto_pub_sub_message::Message) -> ProtoPubSubMessage {
544    let now = SystemTime::now()
545        .duration_since(SystemTime::UNIX_EPOCH)
546        .expect("failed to get millis since epoch");
547
548    ProtoPubSubMessage {
549        timestamp: Some(now.into_proto()),
550        message: Some(message),
551    }
552}
553
554impl GrpcPubSubSender {
555    fn send(&self, message: proto_pub_sub_message::Message, metrics: &PubSubClientCallMetrics) {
556        let size = message.encoded_len();
557
558        match self.requests.send(create_request(message)) {
559            Ok(_) => {
560                metrics.succeeded.inc();
561                metrics.bytes_sent.inc_by(u64::cast_from(size));
562            }
563            Err(err) => {
564                metrics.failed.inc();
565                debug!("error sending client message: {}", err);
566            }
567        }
568    }
569}
570
571impl PubSubSenderInternal for GrpcPubSubSender {
572    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
573        self.send(
574            proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
575                shard_id: shard_id.into_proto(),
576                seqno: diff.seqno.into_proto(),
577                diff: diff.data.clone(),
578            }),
579            &self.metrics.pubsub_client.sender.push,
580        )
581    }
582
583    fn subscribe(&self, shard_id: &ShardId) {
584        self.send(
585            proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
586                shard_id: shard_id.into_proto(),
587            }),
588            &self.metrics.pubsub_client.sender.subscribe,
589        )
590    }
591
592    fn unsubscribe(&self, shard_id: &ShardId) {
593        self.send(
594            proto_pub_sub_message::Message::Unsubscribe(ProtoUnsubscribe {
595                shard_id: shard_id.into_proto(),
596            }),
597            &self.metrics.pubsub_client.sender.unsubscribe,
598        )
599    }
600}
601
/// A wrapper for a [PubSubSenderInternal] that implements [PubSubSender]
/// by maintaining a map of active shard subscriptions to their tokens.
#[derive(Debug)]
struct SubscriptionTrackingSender {
    // The underlying sender that actually issues the RPCs.
    delegate: Arc<dyn PubSubSenderInternal>,
    // Weak references, so dropping the last token (not this map) ends a
    // subscription; dead entries are pruned lazily.
    subscribes: Arc<Mutex<BTreeMap<ShardId, Weak<ShardSubscriptionToken>>>>,
}
609
610impl SubscriptionTrackingSender {
611    fn new(sender: Arc<dyn PubSubSenderInternal>) -> Self {
612        Self {
613            delegate: sender,
614            subscribes: Default::default(),
615        }
616    }
617
618    fn subscriptions(&self) -> Vec<ShardId> {
619        let mut subscribes = self.subscribes.lock().expect("lock");
620        let mut out = Vec::with_capacity(subscribes.len());
621        subscribes.retain(|shard_id, token| {
622            if token.upgrade().is_none() {
623                false
624            } else {
625                debug!("reconnecting to: {}", shard_id);
626                out.push(*shard_id);
627                true
628            }
629        });
630        out
631    }
632}
633
634impl PubSubSender for SubscriptionTrackingSender {
635    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
636        self.delegate.push_diff(shard_id, diff)
637    }
638
639    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
640        let mut subscribes = self.subscribes.lock().expect("lock");
641        if let Some(token) = subscribes.get(shard_id) {
642            match token.upgrade() {
643                None => assert!(subscribes.remove(shard_id).is_some()),
644                Some(token) => {
645                    return Arc::clone(&token);
646                }
647            }
648        }
649
650        let pubsub_sender = Arc::clone(&self.delegate);
651        let token = Arc::new(ShardSubscriptionToken {
652            shard_id: *shard_id,
653            sender: pubsub_sender,
654        });
655
656        assert!(
657            subscribes
658                .insert(*shard_id, Arc::downgrade(&token))
659                .is_none()
660        );
661
662        self.delegate.subscribe(shard_id);
663
664        token
665    }
666}
667
/// A wrapper intended to provide client-side metrics for a connection
/// that communicates directly with the server state, such as one created
/// by [PersistGrpcPubSubServer::new_same_process_connection].
#[derive(Debug)]
pub struct MetricsSameProcessPubSubSender {
    // Snapshot of [PUBSUB_SAME_PROCESS_DELEGATE_ENABLED] taken at
    // construction; when false, `subscribe` returns a no-op token.
    delegate_subscribe: bool,
    metrics: Arc<Metrics>,
    delegate: Arc<dyn PubSubSender>,
}
677
678impl MetricsSameProcessPubSubSender {
679    /// Returns a new [MetricsSameProcessPubSubSender], wrapping the given
680    /// `Arc<dyn PubSubSender>`'s calls to provide client-side metrics.
681    pub fn new(
682        cfg: &PersistConfig,
683        pubsub_sender: Arc<dyn PubSubSender>,
684        metrics: Arc<Metrics>,
685    ) -> Self {
686        Self {
687            delegate_subscribe: PUBSUB_SAME_PROCESS_DELEGATE_ENABLED.get(cfg),
688            delegate: pubsub_sender,
689            metrics,
690        }
691    }
692}
693
694impl PubSubSender for MetricsSameProcessPubSubSender {
695    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
696        self.delegate.push_diff(shard_id, diff);
697        self.metrics.pubsub_client.sender.push.succeeded.inc();
698    }
699
700    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
701        if self.delegate_subscribe {
702            let delegate = Arc::clone(&self.delegate);
703            delegate.subscribe(shard_id)
704        } else {
705            // Create a no-op token that does not subscribe nor unsubscribe.
706            // This is ideal for single-process persist setups, since the sender and
707            // receiver should already share a state cache... but if the diffs are
708            // generated remotely but applied on the server, this may cause us to fall
709            // back to polling consensus.
710            Arc::new(ShardSubscriptionToken {
711                shard_id: *shard_id,
712                sender: Arc::new(NoopPubSubSender),
713            })
714        }
715    }
716}
717
/// A sender whose every operation is a deliberate no-op; used when PubSub is
/// disabled or delegation is turned off.
#[derive(Debug)]
pub(crate) struct NoopPubSubSender;

impl PubSubSenderInternal for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
    fn subscribe(&self, _shard_id: &ShardId) {}
    fn unsubscribe(&self, _shard_id: &ShardId) {}
}
726
impl PubSubSender for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}

    // Returns a real token, but its drop-time unsubscribe routes back to this
    // no-op sender, so the whole subscription lifecycle has no effect.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: self,
        })
    }
}
737
/// Spawns a Tokio task that consumes a [PubSubReceiver], applying its diffs to a [StateCache].
pub(crate) fn subscribe_state_cache_to_pubsub(
    cache: Arc<StateCache>,
    mut pubsub_receiver: Box<dyn PubSubReceiver>,
) -> JoinHandle<()> {
    // Local cache of weak references to per-shard state, so repeated diffs for
    // the same shard can skip the StateCache lookup. Weak refs ensure this
    // task never keeps a shard's state alive on its own.
    let mut state_refs: HashMap<ShardId, Weak<dyn DynState>> = HashMap::new();
    let receiver_metrics = cache.metrics.pubsub_client.receiver.clone();

    mz_ore::task::spawn(
        || "persist::rpc::client::state_cache_diff_apply",
        async move {
            // Runs until the receiver stream ends (i.e. the connection drops).
            while let Some(msg) = pubsub_receiver.next().await {
                match msg.message {
                    Some(proto_pub_sub_message::Message::PushDiff(diff)) => {
                        receiver_metrics.push_received.inc();
                        let shard_id = diff.shard_id.into_rust().expect("valid shard id");
                        let diff = VersionedData {
                            seqno: diff.seqno.into_rust().expect("valid SeqNo"),
                            data: diff.diff,
                        };
                        debug!(
                            "applying pubsub diff {} {} {}",
                            shard_id,
                            diff.seqno,
                            diff.data.len()
                        );

                        let mut pushed_diff = false;
                        if let Some(state_ref) = state_refs.get(&shard_id) {
                            // common case: we have a reference to the shard state already
                            // and can apply our diff directly.
                            if let Some(state) = state_ref.upgrade() {
                                state.push_diff(diff.clone());
                                pushed_diff = true;
                                receiver_metrics.state_pushed_diff_fast_path.inc();
                            }
                        }

                        if !pushed_diff {
                            // uncommon case: we either don't have a reference yet, or ours
                            // is out-of-date (e.g. the shard was dropped and then re-added
                            // to StateCache). here we'll fetch the latest, try to apply the
                            // diff again, and update our local reference.
                            let state_ref = cache.get_state_weak(&shard_id);
                            match state_ref {
                                None => {
                                    // the shard is not (or no longer) in the cache;
                                    // drop any stale local reference.
                                    state_refs.remove(&shard_id);
                                }
                                Some(state_ref) => {
                                    if let Some(state) = state_ref.upgrade() {
                                        state.push_diff(diff);
                                        pushed_diff = true;
                                        state_refs.insert(shard_id, state_ref);
                                    } else {
                                        // the cache's reference is itself dead; forget it.
                                        state_refs.remove(&shard_id);
                                    }
                                }
                            }

                            if pushed_diff {
                                receiver_metrics.state_pushed_diff_slow_path_succeeded.inc();
                            } else {
                                receiver_metrics.state_pushed_diff_slow_path_failed.inc();
                            }
                        }

                        // Record an approximate end-to-end latency for the diff.
                        // NOTE(review): compares the sender's wall clock to ours,
                        // so it assumes the two clocks are roughly comparable.
                        if let Some(send_timestamp) = msg.timestamp {
                            let send_timestamp =
                                send_timestamp.into_rust().expect("valid timestamp");
                            let now = SystemTime::now()
                                .duration_since(SystemTime::UNIX_EPOCH)
                                .expect("failed to get millis since epoch");
                            receiver_metrics
                                .approx_diff_latency_seconds
                                .observe((now.saturating_sub(send_timestamp)).as_secs_f64());
                        }
                    }
                    // Clients only expect PushDiff messages; anything else is
                    // logged and counted but otherwise ignored.
                    ref msg @ None | ref msg @ Some(_) => {
                        warn!("pubsub client received unexpected message: {:?}", msg);
                        receiver_metrics.unknown_message_received.inc();
                    }
                }
            }
        },
    )
}
824
/// Internal state of a PubSub server implementation.
///
/// Shared (behind an `Arc`) by every [PubSubConnection] handed out by
/// [PubSubState::new_connection].
#[derive(Debug)]
pub(crate) struct PubSubState {
    /// Assigns a unique ID to each incoming connection.
    connection_id_counter: AtomicUsize,
    /// Maintains a mapping of `ShardId --> [ConnectionId -> Tx]`, where `Tx`
    /// is the channel used to push messages out to the subscribed connection.
    shard_subscribers:
        Arc<RwLock<BTreeMap<ShardId, BTreeMap<usize, Sender<Result<ProtoPubSubMessage, Status>>>>>>,
    /// Active connections, by connection id.
    connections: Arc<RwLock<HashSet<usize>>>,
    /// Server-side metrics.
    metrics: Arc<PubSubServerMetrics>,
}
838
impl PubSubState {
    /// Registers a new connection with this state, assigning it a unique id.
    ///
    /// `notifier` is the channel used to push messages out to the connection's
    /// client. The returned [PubSubConnection] deregisters itself (and all of
    /// its subscriptions) from this state when dropped.
    fn new_connection(
        self: Arc<Self>,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    ) -> PubSubConnection {
        let connection_id = self.connection_id_counter.fetch_add(1, Ordering::SeqCst);
        {
            debug!("inserting connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            // Ids come from a monotonically increasing counter, so a collision
            // here would be a bug.
            assert!(connections.insert(connection_id));
        }

        self.metrics.active_connections.inc();
        PubSubConnection {
            connection_id,
            notifier,
            state: self,
        }
    }

    /// Removes a connection, dropping all of its shard subscriptions.
    ///
    /// Called from [PubSubConnection]'s `Drop` impl. Panics if the connection
    /// id was never registered or was already removed.
    fn remove_connection(&self, connection_id: usize) {
        let now = Instant::now();

        {
            debug!("removing connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(
                connections.remove(&connection_id),
                "unknown connection id: {}",
                connection_id
            );
        }

        {
            let mut subscribers = self.shard_subscribers.write().expect("lock poisoned");
            // Remove this connection from every shard's subscriber map, and
            // drop shard entries that are left with no subscribers at all.
            subscribers.retain(|_shard, connections_for_shard| {
                connections_for_shard.remove(&connection_id);
                !connections_for_shard.is_empty()
            });
        }

        self.metrics
            .connection_cleanup_seconds
            .inc_by(now.elapsed().as_secs_f64());
        self.metrics.active_connections.dec();
    }

    /// Broadcasts `data` for `shard_id` to every subscribed connection except
    /// the sender itself.
    ///
    /// Sends are non-blocking: if a subscriber's channel is full, the diff is
    /// dropped for that subscriber (and counted in metrics). Panics if
    /// `connection_id` is not an active connection.
    fn push_diff(&self, connection_id: usize, shard_id: &ShardId, data: &VersionedData) {
        let now = Instant::now();
        self.metrics.push_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        let subscribers = self.shard_subscribers.read().expect("lock poisoned");
        if let Some(subscribed_connections) = subscribers.get(shard_id) {
            let mut num_sent = 0;
            let mut data_size = 0;

            for (subscribed_conn_id, tx) in subscribed_connections {
                // skip sending the diff back to the original sender
                if *subscribed_conn_id == connection_id {
                    continue;
                }
                debug!(
                    "server forwarding req to conn {}: {} {} {}",
                    subscribed_conn_id,
                    &shard_id,
                    data.seqno,
                    data.data.len()
                );
                let req = create_request(proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                    seqno: data.seqno.into_proto(),
                    shard_id: shard_id.to_string(),
                    diff: Bytes::clone(&data.data),
                }));
                // NOTE(review): a request is built per recipient but its payload
                // is the same each time, so the last `encoded_len` is used below
                // as the per-message size for all successful sends.
                data_size = req.encoded_len();
                match tx.try_send(Ok(req)) {
                    Ok(_) => {
                        num_sent += 1;
                    }
                    Err(TrySendError::Full(_)) => {
                        self.metrics.broadcasted_diff_dropped_channel_full.inc();
                    }
                    // Receiver went away; the connection is cleaned up on drop.
                    Err(TrySendError::Closed(_)) => {}
                };
            }

            self.metrics.broadcasted_diff_count.inc_by(num_sent);
            self.metrics
                .broadcasted_diff_bytes
                .inc_by(num_sent * u64::cast_from(data_size));
        }

        self.metrics
            .push_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Subscribes `connection_id` to `shard_id`, recording `notifier` as the
    /// channel to push diffs to.
    ///
    /// Idempotent: re-subscribing replaces the stored notifier. Panics if the
    /// connection id is unknown.
    fn subscribe(
        &self,
        connection_id: usize,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
        shard_id: &ShardId,
    ) {
        let now = Instant::now();
        self.metrics.subscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            subscribed_shards
                .entry(*shard_id)
                .or_default()
                .insert(connection_id, notifier);
        }

        self.metrics
            .subscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Unsubscribes `connection_id` from `shard_id`, removing the shard's
    /// entry entirely once it has no subscribers left.
    ///
    /// A no-op if there is no such subscription. Panics if the connection id
    /// is unknown.
    fn unsubscribe(&self, connection_id: usize, shard_id: &ShardId) {
        let now = Instant::now();
        self.metrics.unsubscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            if let Entry::Occupied(mut entry) = subscribed_shards.entry(*shard_id) {
                let subscribed_connections = entry.get_mut();
                subscribed_connections.remove(&connection_id);

                // Drop the shard's entry once the last subscriber is gone.
                if subscribed_connections.is_empty() {
                    entry.remove_entry();
                }
            }
        }

        self.metrics
            .unsubscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Creates an empty state with fresh metrics, for tests.
    #[cfg(test)]
    fn new_for_test() -> Self {
        Self {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(PubSubServerMetrics::new(&MetricsRegistry::new())),
        }
    }

    /// Returns the ids of all active connections, for tests.
    #[cfg(test)]
    fn active_connections(&self) -> HashSet<usize> {
        self.connections.read().expect("lock").clone()
    }

    /// Returns the set of shards `connection_id` is subscribed to, for tests.
    #[cfg(test)]
    fn subscriptions(&self, connection_id: usize) -> HashSet<ShardId> {
        let mut shards = HashSet::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            if subscribed_connections.contains_key(&connection_id) {
                shards.insert(*shard);
            }
        }

        shards
    }

    /// Returns the number of subscribers per shard, for tests.
    #[cfg(test)]
    fn shard_subscription_counts(&self) -> mz_ore::collections::HashMap<ShardId, usize> {
        let mut shards = mz_ore::collections::HashMap::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            shards.insert(*shard, subscribed_connections.len());
        }

        shards
    }
}
1046
/// A gRPC-based implementation of a Persist PubSub server.
#[derive(Debug)]
pub struct PersistGrpcPubSubServer {
    /// Persist configuration; used to size channels and read dynamic configs.
    cfg: PersistConfig,
    /// Shared connection/subscription state, also handed to in-process clients.
    state: Arc<PubSubState>,
}
1053
1054impl PersistGrpcPubSubServer {
1055    /// Creates a new [PersistGrpcPubSubServer].
1056    pub fn new(cfg: &PersistConfig, metrics_registry: &MetricsRegistry) -> Self {
1057        let metrics = PubSubServerMetrics::new(metrics_registry);
1058        let state = Arc::new(PubSubState {
1059            connection_id_counter: AtomicUsize::new(0),
1060            shard_subscribers: Default::default(),
1061            connections: Default::default(),
1062            metrics: Arc::new(metrics),
1063        });
1064
1065        PersistGrpcPubSubServer {
1066            cfg: cfg.clone(),
1067            state,
1068        }
1069    }
1070
1071    /// Creates a connection to [PersistGrpcPubSubServer] that is directly connected
1072    /// to the server state. Calls into this connection do not go over the network
1073    /// nor require message serde.
1074    pub fn new_same_process_connection(&self) -> PubSubClientConnection {
1075        let (tx, rx) =
1076            tokio::sync::mpsc::channel(PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&self.cfg));
1077        let sender: Arc<dyn PubSubSender> = Arc::new(SubscriptionTrackingSender::new(Arc::new(
1078            Arc::clone(&self.state).new_connection(tx),
1079        )));
1080
1081        PubSubClientConnection {
1082            sender,
1083            receiver: Box::new(
1084                ReceiverStream::new(rx).map(|x| x.expect("cannot receive grpc errors locally")),
1085            ),
1086        }
1087    }
1088
1089    /// Starts the gRPC server. Consumes `self` and runs until the task is cancelled.
1090    pub async fn serve(self, listen_addr: SocketAddr) -> Result<(), anyhow::Error> {
1091        // Increase the default message decoding limit to avoid unnecessary panics
1092        tonic::transport::Server::builder()
1093            .add_service(
1094                ProtoPersistPubSubServer::new(self)
1095                    .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
1096            )
1097            .serve(listen_addr)
1098            .await?;
1099        Ok(())
1100    }
1101
1102    /// Starts the gRPC server with the given listener stream.
1103    /// Consumes `self` and runs until the task is cancelled.
1104    pub async fn serve_with_stream(
1105        self,
1106        listener: tokio_stream::wrappers::TcpListenerStream,
1107    ) -> Result<(), anyhow::Error> {
1108        tonic::transport::Server::builder()
1109            .add_service(
1110                ProtoPersistPubSubServer::new(self)
1111                    .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
1112            )
1113            .serve_with_incoming(listener)
1114            .await?;
1115        Ok(())
1116    }
1117}
1118
1119#[async_trait]
1120impl proto_persist_pub_sub_server::ProtoPersistPubSub for PersistGrpcPubSubServer {
1121    type PubSubStream = Pin<Box<dyn Stream<Item = Result<ProtoPubSubMessage, Status>> + Send>>;
1122
1123    #[mz_ore::instrument(name = "persist::rpc::server", level = "info")]
1124    async fn pub_sub(
1125        &self,
1126        request: Request<Streaming<ProtoPubSubMessage>>,
1127    ) -> Result<Response<Self::PubSubStream>, Status> {
1128        let caller_id = request
1129            .metadata()
1130            .get(AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY))
1131            .map(|key| key.to_str().ok())
1132            .flatten()
1133            .map(|key| key.to_string())
1134            .unwrap_or_else(|| "unknown".to_string());
1135        info!("Received Persist PubSub connection from: {:?}", caller_id);
1136
1137        let mut in_stream = request.into_inner();
1138        let (tx, rx) =
1139            tokio::sync::mpsc::channel(PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE.get(&self.cfg));
1140
1141        let caller = caller_id.clone();
1142        let cfg = Arc::clone(&self.cfg.configs);
1143        let server_state = Arc::clone(&self.state);
1144        // this spawn here to cleanup after connection error / disconnect, otherwise the stream
1145        // would not be polled after the connection drops. in our case, we want to clear the
1146        // connection and its subscriptions from our shared state when it drops.
1147        let connection_span = info_span!("connection", caller_id);
1148        mz_ore::task::spawn(
1149            || format!("persist_pubsub_connection({})", caller),
1150            async move {
1151                let connection = server_state.new_connection(tx);
1152                while let Some(result) = in_stream.next().await {
1153                    let req = match result {
1154                        Ok(req) => req,
1155                        Err(err) => {
1156                            warn!("pubsub connection err: {}", err);
1157                            break;
1158                        }
1159                    };
1160
1161                    match req.message {
1162                        None => {
1163                            warn!("received empty message from: {}", caller_id);
1164                        }
1165                        Some(proto_pub_sub_message::Message::PushDiff(req)) => {
1166                            let shard_id = req.shard_id.parse().expect("valid shard id");
1167                            let diff = VersionedData {
1168                                seqno: req.seqno.into_rust().expect("valid seqno"),
1169                                data: req.diff.clone(),
1170                            };
1171                            if PUBSUB_PUSH_DIFF_ENABLED.get(&cfg) {
1172                                connection.push_diff(&shard_id, &diff);
1173                            }
1174                        }
1175                        Some(proto_pub_sub_message::Message::Subscribe(diff)) => {
1176                            let shard_id = diff.shard_id.parse().expect("valid shard id");
1177                            connection.subscribe(&shard_id);
1178                        }
1179                        Some(proto_pub_sub_message::Message::Unsubscribe(diff)) => {
1180                            let shard_id = diff.shard_id.parse().expect("valid shard id");
1181                            connection.unsubscribe(&shard_id);
1182                        }
1183                    }
1184                }
1185
1186                info!("Persist PubSub connection ended: {:?}", caller_id);
1187            }
1188            .instrument(connection_span),
1189        );
1190
1191        let out_stream: Self::PubSubStream = Box::pin(ReceiverStream::new(rx));
1192        Ok(Response::new(out_stream))
1193    }
1194}
1195
/// An active connection managed by [PubSubState].
///
/// When dropped, removes itself from [PubSubState], clearing all of its subscriptions.
#[derive(Debug)]
pub(crate) struct PubSubConnection {
    /// Unique id assigned by [PubSubState::new_connection].
    connection_id: usize,
    /// Channel used to push messages out to this connection's client.
    notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    /// Handle to the shared server state this connection is registered with.
    state: Arc<PubSubState>,
}
1205
1206impl PubSubSenderInternal for PubSubConnection {
1207    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
1208        self.state.push_diff(self.connection_id, shard_id, diff)
1209    }
1210
1211    fn subscribe(&self, shard_id: &ShardId) {
1212        self.state
1213            .subscribe(self.connection_id, self.notifier.clone(), shard_id)
1214    }
1215
1216    fn unsubscribe(&self, shard_id: &ShardId) {
1217        self.state.unsubscribe(self.connection_id, shard_id)
1218    }
1219}
1220
impl Drop for PubSubConnection {
    fn drop(&mut self) {
        // Deregister from the shared state, dropping any subscriptions held by
        // this connection (see [PubSubState::remove_connection]).
        self.state.remove_connection(self.connection_id)
    }
}
1226
/// Unit tests for [PubSubState]: connection registration, subscription
/// bookkeeping, and diff broadcast behavior.
#[cfg(test)]
mod pubsub_state {
    use std::str::FromStr;
    use std::sync::Arc;
    use std::sync::LazyLock;

    use bytes::Bytes;
    use mz_ore::collections::HashSet;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use tokio::sync::mpsc::Receiver;
    use tokio::sync::mpsc::error::TryRecvError;
    use tonic::Status;

    use crate::ShardId;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::rpc::{PubSubSenderInternal, PubSubState};

    // Two distinct, fixed shard ids used throughout the tests.
    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());

    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };

    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    // The next four tests verify that calls referencing a connection id that
    // was never registered panic (id 100 is never handed out).

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_push_diff() {
        let state = Arc::new(PubSubState::new_for_test());
        state.push_diff(100, &SHARD_ID_0, &VERSIONED_DATA_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_subscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        let (tx, _) = tokio::sync::mpsc::channel(100);
        state.subscribe(100, tx, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_unsubscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        state.unsubscribe(100, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_remove() {
        let state = Arc::new(PubSubState::new_for_test());
        state.remove_connection(100)
    }

    // Exercises the full lifecycle of a single connection: registration,
    // subscribe/unsubscribe bookkeeping, and cleanup on drop.
    #[mz_ore::test]
    fn test_single_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx, mut rx) = tokio::sync::mpsc::channel(100);
        let connection = Arc::clone(&state).new_connection(tx);

        assert_eq!(
            state.active_connections(),
            HashSet::from([connection.connection_id])
        );

        // no messages should have been broadcasted yet
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        connection.push_diff(
            &SHARD_ID_0,
            &VersionedData {
                seqno: SeqNo::minimum(),
                data: Bytes::new(),
            },
        );

        // server should not broadcast a message back to originating client
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        // a connection can subscribe to a shard
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([SHARD_ID_0.clone()])
        );

        // a connection can unsubscribe
        connection.unsubscribe(&SHARD_ID_0);
        assert!(state.subscriptions(connection.connection_id).is_empty());

        // a connection can subscribe to many shards
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_1);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // and to a single shard many times idempotently
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // dropping the connection should unsubscribe all shards and unregister the connection
        let connection_id = connection.connection_id;
        drop(connection);
        assert!(state.subscriptions(connection_id).is_empty());
        assert!(state.active_connections().is_empty());
    }

    // Verifies broadcast fan-out across multiple connections: diffs reach all
    // subscribers except the sender, and dropped connections are cleaned up.
    #[mz_ore::test]
    fn test_many_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx1, mut rx1) = tokio::sync::mpsc::channel(100);
        let conn1 = Arc::clone(&state).new_connection(tx1);

        let (tx2, mut rx2) = tokio::sync::mpsc::channel(100);
        let conn2 = Arc::clone(&state).new_connection(tx2);

        let (tx3, mut rx3) = tokio::sync::mpsc::channel(100);
        let conn3 = Arc::clone(&state).new_connection(tx3);

        conn1.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_1);

        assert_eq!(
            state.active_connections(),
            HashSet::from([
                conn1.connection_id,
                conn2.connection_id,
                conn3.connection_id
            ])
        );

        // broadcast a diff to a shard subscribed to by several connections
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx1, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff shared by publisher. it should not receive the diff back.
        conn1.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff to a shard subscribed to by one connection
        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert_push(&mut rx2, &SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // broadcast a diff to a shard subscribed to by no connections
        conn2.unsubscribe(&SHARD_ID_1);
        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert!(matches!(rx2.try_recv(), Err(TryRecvError::Empty)));
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // dropping connections unsubscribes them
        let conn1_id = conn1.connection_id;
        drop(conn1);
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Disconnected)));
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        assert!(state.subscriptions(conn1_id).is_empty());
        assert_eq!(
            state.subscriptions(conn2.connection_id),
            HashSet::from([*SHARD_ID_0])
        );
        assert_eq!(state.subscriptions(conn3.connection_id), HashSet::new());
        assert_eq!(
            state.active_connections(),
            HashSet::from([conn2.connection_id, conn3.connection_id])
        );
    }

    /// Asserts that the next message in `rx` is a `PushDiff` carrying exactly
    /// `shard` and `data`.
    fn assert_push(
        rx: &mut Receiver<Result<ProtoPubSubMessage, Status>>,
        shard: &ShardId,
        data: &VersionedData,
    ) {
        let message = rx
            .try_recv()
            .expect("message in channel")
            .expect("pubsub")
            .message
            .expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }
}
1442
1443#[cfg(test)]
1444mod grpc {
1445    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
1446    use std::str::FromStr;
1447    use std::sync::Arc;
1448    use std::time::{Duration, Instant};
1449
1450    use bytes::Bytes;
1451    use futures_util::FutureExt;
1452    use mz_dyncfg::ConfigUpdates;
1453    use mz_ore::assert_none;
1454    use mz_ore::collections::HashMap;
1455    use mz_ore::metrics::MetricsRegistry;
1456    use mz_persist::location::{SeqNo, VersionedData};
1457    use mz_proto::RustType;
1458    use std::sync::LazyLock;
1459    use tokio::net::TcpListener;
1460    use tokio_stream::StreamExt;
1461    use tokio_stream::wrappers::TcpListenerStream;
1462
1463    use crate::ShardId;
1464    use crate::cfg::PersistConfig;
1465    use crate::internal::service::ProtoPubSubMessage;
1466    use crate::internal::service::proto_pub_sub_message::Message;
1467    use crate::metrics::Metrics;
1468    use crate::rpc::{
1469        GrpcPubSubClient, PUBSUB_CLIENT_ENABLED, PUBSUB_RECONNECT_BACKOFF, PersistGrpcPubSubServer,
1470        PersistPubSubClient, PersistPubSubClientConfig, PubSubState,
1471    };
1472
1473    static SHARD_ID_0: LazyLock<ShardId> =
1474        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
1475    static SHARD_ID_1: LazyLock<ShardId> =
1476        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
1477    const VERSIONED_DATA_0: VersionedData = VersionedData {
1478        seqno: SeqNo(0),
1479        data: Bytes::from_static(&[0, 1, 2, 3]),
1480    };
1481    const VERSIONED_DATA_1: VersionedData = VersionedData {
1482        seqno: SeqNo(1),
1483        data: Bytes::from_static(&[4, 5, 6, 7]),
1484    };
1485
1486    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
1487    const SUBSCRIPTIONS_TIMEOUT: Duration = Duration::from_secs(3);
1488    const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);
1489
1490    // NB: we use separate runtimes for client and server throughout these tests to cleanly drop
1491    // ALL tasks (including spawned child tasks) associated with one end of a connection, to most
1492    // closely model an actual disconnect.
1493
1494    #[mz_ore::test]
1495    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
1496    fn grpc_server() {
1497        let metrics = Arc::new(Metrics::new(
1498            &test_persist_config(),
1499            &MetricsRegistry::new(),
1500        ));
1501        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
1502        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
1503
1504        // start the server
1505        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
1506        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));
1507
1508        // start a client.
1509        {
1510            let _guard = client_runtime.enter();
1511            mz_ore::task::spawn(|| "client".to_string(), async move {
1512                let client = GrpcPubSubClient::connect(
1513                    PersistPubSubClientConfig {
1514                        url: format!("http://{}", addr),
1515                        caller_id: "client".to_string(),
1516                        persist_cfg: test_persist_config(),
1517                    },
1518                    metrics,
1519                );
1520                let _token = client.sender.subscribe(&SHARD_ID_0);
1521                tokio::time::sleep(Duration::MAX).await;
1522            });
1523        }
1524
1525        // wait until the client is connected and subscribed
1526        server_runtime.block_on(async {
1527            poll_until_true(CONNECT_TIMEOUT, || {
1528                server_state.active_connections().len() == 1
1529            })
1530            .await;
1531            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1532                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
1533            })
1534            .await
1535        });
1536
1537        // drop the client
1538        client_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);
1539
1540        // server should notice the client dropping and clean up its state
1541        server_runtime.block_on(async {
1542            poll_until_true(CONNECT_TIMEOUT, || {
1543                server_state.active_connections().is_empty()
1544            })
1545            .await;
1546            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
1547                server_state.shard_subscription_counts() == HashMap::new()
1548            })
1549            .await
1550        });
1551    }
1552
    // Verifies that a `GrpcPubSubClient` can be created before any server
    // exists, buffers subscribe/unsubscribe calls while disconnected, and
    // transparently reconnects and rehydrates its live subscriptions when a
    // server (re)appears at the same address. Server and client run on
    // separate runtimes so each side can be shut down independently.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    fn grpc_client_sender_reconnects() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        // Bind the listener now (so we know the addr), but don't serve yet —
        // the client below must come up first to exercise reconnect behavior.
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // start a client
        let client = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // we can subscribe before connecting to the pubsub server
        let _token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        // we can subscribe and unsubscribe before connecting to the pubsub server
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        drop(_token_2);

        // create the server after the client is up
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            // client connects automatically once the server is up
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            // client rehydrated its subscriptions. notably, only includes the shard that
            // still has an active token
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await;
        });

        // kill the server
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // client can still send requests without error
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);

        // create a new server
        // (re-binding the exact same addr so the client's configured URL
        // remains valid across the restart)
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            // client automatically reconnects to new server
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            // and rehydrates its subscriptions, including the new one that was sent
            // while the server was unavailable.
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts()
                    == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
            })
            .await;
        });
    }
1633
    // Verifies the reference-counted subscription-token semantics of the
    // client sender: subscribing returns a token, repeated subscriptions to
    // the same shard share one token (one server-side subscription), and the
    // server-side subscription is removed only once every clone of the token
    // has been dropped.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    async fn grpc_client_sender_subscription_tokens() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));

        let (addr, tcp_listener_stream) = new_tcp_listener().await;
        let server_state = spawn_server(tcp_listener_stream).await;

        let client = GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client".to_string(),
                persist_cfg: test_persist_config(),
            },
            metrics,
        );

        // our client connects
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;

        // we can subscribe to a shard, receiving back a token
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // dropping the token will unsubscribe our client
        drop(token);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // we can resubscribe to a shard
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // we can subscribe many times idempotently, receiving back Arcs to the same token
        // (strong_count == 3 proves all three calls share one underlying token,
        // and the server still sees exactly one subscription for the shard)
        let token2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let token3 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        assert_eq!(Arc::strong_count(&token), 3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // dropping all of the tokens will unsubscribe the shard
        drop(token);
        drop(token2);
        drop(token3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // we can subscribe to many shards
        let _token0 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let _token1 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts()
                == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
        })
        .await;
    }
1708
    // Verifies the receiver half of the client: a diff pushed by one
    // subscribed client is broadcast to the *other* subscribed client (never
    // echoed back to the sender), receivers can be polled harmlessly while
    // disconnected, and pushes flow again after a server restart.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `socket` on OS `linux`
    fn grpc_client_receiver() {
        let metrics = Arc::new(Metrics::new(
            // NOTE(review): the other tests in this file build metrics from
            // test_persist_config(); presumably the plain test config is fine
            // here since the pubsub dyncfgs are passed per-client below — but
            // confirm the inconsistency is intentional.
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // create two clients, so we can test that broadcast messages are received by the other
        let mut client_1 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_1".to_string(),
                    persist_cfg: test_persist_config(),
                },
                Arc::clone(&metrics),
            )
        });
        let mut client_2 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_2".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // we can check our receiver output before connecting to the server.
        // these calls are race-y, since there's no guarantee on the time it
        // would take for a message to be received were one to have been sent,
        // but, better than nothing?
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // start the server
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // wait until both clients are connected
        server_runtime.block_on(poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 2
        }));

        // no messages have been broadcast yet
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // subscribe and send a diff
        let _token_client_1 = Arc::clone(&client_1.sender).subscribe(&SHARD_ID_0);
        let _token_client_2 = Arc::clone(&client_2.sender).subscribe(&SHARD_ID_0);
        server_runtime.block_on(poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
        }));

        // the subscriber non-sender client receives the diff
        // (client_1 pushed it, so only client_2 should see it)
        client_1.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_1);
        assert_none!(client_1.receiver.next().now_or_never());
        client_runtime.block_on(async {
            assert_push(
                client_2.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_1,
            )
        });

        // kill the server
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // receivers can still be polled without error
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // create a new server
        // (re-binding the same addr so both clients' configured URLs stay valid)
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // client automatically reconnects to new server and rehydrates subscriptions
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 2
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
            })
            .await;
        });

        // pushing and receiving diffs works as expected.
        // this time we'll push from the other client.
        client_2.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        client_runtime.block_on(async {
            assert_push(
                client_1.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_0,
            )
        });
        assert_none!(client_2.receiver.next().now_or_never());
    }
1821
1822    async fn new_tcp_listener() -> (SocketAddr, TcpListenerStream) {
1823        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
1824        let tcp_listener = TcpListener::bind(addr).await.expect("tcp listener");
1825
1826        (
1827            tcp_listener.local_addr().expect("bound to local address"),
1828            TcpListenerStream::new(tcp_listener),
1829        )
1830    }
1831
1832    #[allow(clippy::unused_async)]
1833    async fn spawn_server(tcp_listener_stream: TcpListenerStream) -> Arc<PubSubState> {
1834        let server = PersistGrpcPubSubServer::new(&test_persist_config(), &MetricsRegistry::new());
1835        let server_state = Arc::clone(&server.state);
1836
1837        let _server_task = mz_ore::task::spawn(|| "server".to_string(), async move {
1838            server.serve_with_stream(tcp_listener_stream).await
1839        });
1840        server_state
1841    }
1842
1843    async fn poll_until_true<F>(timeout: Duration, f: F)
1844    where
1845        F: Fn() -> bool,
1846    {
1847        let now = Instant::now();
1848        loop {
1849            if f() {
1850                return;
1851            }
1852
1853            if now.elapsed() > timeout {
1854                panic!("timed out");
1855            }
1856
1857            tokio::time::sleep(Duration::from_millis(1)).await;
1858        }
1859    }
1860
1861    fn assert_push(message: ProtoPubSubMessage, shard: &ShardId, data: &VersionedData) {
1862        let message = message.message.expect("proto contains message");
1863        match message {
1864            Message::PushDiff(x) => {
1865                assert_eq!(x.shard_id, shard.into_proto());
1866                assert_eq!(x.seqno, data.seqno.into_proto());
1867                assert_eq!(x.diff, data.data);
1868            }
1869            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
1870        };
1871    }
1872
1873    fn test_persist_config() -> PersistConfig {
1874        let cfg = PersistConfig::new_for_tests();
1875
1876        let mut updates = ConfigUpdates::default();
1877        updates.add(&PUBSUB_CLIENT_ENABLED, true);
1878        updates.add(&PUBSUB_RECONNECT_BACKOFF, Duration::ZERO);
1879        cfg.apply_from(&updates);
1880
1881        cfg
1882    }
1883}