1use std::collections::BTreeMap;
13use std::collections::btree_map::Entry;
14use std::fmt::{Debug, Formatter};
15use std::net::SocketAddr;
16use std::pin::Pin;
17use std::str::FromStr;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::{Arc, Mutex, RwLock, Weak};
20use std::time::{Duration, Instant, SystemTime};
21
22use anyhow::{Error, anyhow};
23use async_trait::async_trait;
24use bytes::Bytes;
25use futures::Stream;
26use futures_util::StreamExt;
27use mz_dyncfg::Config;
28use mz_ore::cast::CastFrom;
29use mz_ore::collections::{HashMap, HashSet};
30use mz_ore::metrics::MetricsRegistry;
31use mz_ore::retry::RetryResult;
32use mz_ore::task::JoinHandle;
33use mz_persist::location::VersionedData;
34use mz_proto::{ProtoType, RustType};
35use prost::Message;
36use tokio::sync::mpsc::Sender;
37use tokio::sync::mpsc::error::TrySendError;
38use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
39use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
40use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap};
41use tonic::transport::Endpoint;
42use tonic::{Extensions, Request, Response, Status, Streaming};
43use tracing::{Instrument, debug, error, info, info_span, warn};
44
45use crate::ShardId;
46use crate::cache::{DynState, StateCache};
47use crate::cfg::PersistConfig;
48use crate::internal::metrics::{PubSubClientCallMetrics, PubSubServerMetrics};
49use crate::internal::service::proto_persist_pub_sub_client::ProtoPersistPubSubClient;
50use crate::internal::service::proto_persist_pub_sub_server::ProtoPersistPubSubServer;
51use crate::internal::service::{
52 ProtoPubSubMessage, ProtoPushDiff, ProtoSubscribe, ProtoUnsubscribe,
53 proto_persist_pub_sub_server, proto_pub_sub_message,
54};
55use crate::metrics::Metrics;
56
/// Whether to attempt to connect to the Persist PubSub service at all. When
/// false, the client connection loop idles (re-checking periodically) instead
/// of connecting.
pub(crate) const PUBSUB_CLIENT_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_client_enabled",
    true,
    "Whether to connect to the Persist PubSub service.",
);

/// Whether the server applies/broadcasts PushDiff messages it receives.
/// Checked per received message, so it can be toggled at runtime.
pub(crate) const PUBSUB_PUSH_DIFF_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_push_diff_enabled",
    true,
    "Whether to push state diffs to Persist PubSub.",
);

/// Whether same-process senders forward subscribe calls to their delegate;
/// when false, callers receive tokens wired to a no-op sender.
pub(crate) const PUBSUB_SAME_PROCESS_DELEGATE_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_same_process_delegate_enabled",
    true,
    "Whether to push state diffs to Persist PubSub on the same process.",
);

/// Timeout applied to each individual connection attempt.
pub(crate) const PUBSUB_CONNECT_ATTEMPT_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_connect_attempt_timeout",
    Duration::from_secs(5),
    "Timeout per connection attempt to Persist PubSub service.",
);

/// Timeout applied to each request on an established connection.
pub(crate) const PUBSUB_REQUEST_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_request_timeout",
    Duration::from_secs(5),
    "Timeout per request attempt to Persist PubSub service.",
);

/// Cap on the retry backoff while establishing a connection.
pub(crate) const PUBSUB_CONNECT_MAX_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_connect_max_backoff",
    Duration::from_secs(60),
    "Maximum backoff when retrying connection establishment to Persist PubSub service.",
);

/// Capacity of the client-side broadcast channel buffering outgoing messages.
pub(crate) const PUBSUB_CLIENT_SENDER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_sender_channel_size",
    25,
    "Size of channel used to buffer send messages to PubSub service.",
);

/// Capacity of the client-side mpsc channel buffering incoming messages.
pub(crate) const PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_receiver_channel_size",
    25,
    "Size of channel used to buffer received messages from PubSub service.",
);

/// Capacity of each per-connection channel on the server used to deliver
/// broadcasted messages to that connection.
pub(crate) const PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_server_connection_channel_size",
    25,
    "Size of channel used per connection to buffer broadcasted messages from PubSub server.",
);

/// Capacity of the channel the state cache uses to broadcast shard state
/// references.
pub(crate) const PUBSUB_STATE_CACHE_SHARD_REF_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_state_cache_shard_ref_channel_size",
    25,
    "Size of channel used by the state cache to broadcast shard state references.",
);

/// Backoff between reconnection attempts after an established connection
/// fails (the very first attempt happens immediately).
pub(crate) const PUBSUB_RECONNECT_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_reconnect_backoff",
    Duration::from_secs(5),
    "Backoff after an established connection to Persist PubSub service fails.",
);

/// Effectively-unlimited cap passed to tonic's `max_decoding_message_size`
/// on both the client and server sides.
const MAX_GRPC_MESSAGE_SIZE: usize = usize::MAX;
/// A client to the Persist PubSub service.
pub trait PersistPubSubClient {
    /// Connects to the PubSub service, returning send- and receive-side
    /// handles for the connection.
    fn connect(
        pubsub_config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) -> PubSubClientConnection;
}
155
/// The send- and receive-side handles of a connection to the Persist PubSub
/// service.
#[derive(Debug)]
pub struct PubSubClientConnection {
    /// Used to push diffs and manage shard subscriptions.
    pub sender: Arc<dyn PubSubSender>,
    /// Stream of messages received from the service.
    pub receiver: Box<dyn PubSubReceiver>,
}
164
165impl PubSubClientConnection {
166 pub fn new(sender: Arc<dyn PubSubSender>, receiver: Box<dyn PubSubReceiver>) -> Self {
168 Self { sender, receiver }
169 }
170
171 pub fn noop() -> Self {
173 Self {
174 sender: Arc::new(NoopPubSubSender),
175 receiver: Box::new(futures::stream::empty()),
176 }
177 }
178}
179
/// The send-side client interface to Persist PubSub.
pub trait PubSubSender: std::fmt::Debug + Send + Sync {
    /// Pushes a diff for the given shard to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribes the client to `shard_id`'s diffs.
    ///
    /// Returns a token that unsubscribes from the shard when dropped (see
    /// [ShardSubscriptionToken]). Implementations may hand back a clone of an
    /// existing token if the client is already subscribed.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken>;
}
194
/// Internal send-side interface, without the token-based subscription
/// tracking of [PubSubSender]: callers must pair `subscribe` with an explicit
/// `unsubscribe`.
trait PubSubSenderInternal: Debug + Send + Sync {
    /// Pushes a diff for the given shard to subscribers.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribes to `shard_id`'s diffs.
    fn subscribe(&self, shard_id: &ShardId);

    /// Removes the subscription to `shard_id`'s diffs.
    fn unsubscribe(&self, shard_id: &ShardId);
}
213
/// The receive-side client interface to Persist PubSub: a stream of messages
/// pushed by the service.
pub trait PubSubReceiver:
    Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}

// Blanket impl: any stream with the right item/auto-trait bounds is a
// `PubSubReceiver`.
impl<T> PubSubReceiver for T where
    T: Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}
227
/// A token representing a subscription to a shard's diffs. Unsubscribes from
/// the shard when dropped.
pub struct ShardSubscriptionToken {
    /// The shard this token is subscribed to.
    pub(crate) shard_id: ShardId,
    /// The sender used to issue the unsubscribe on drop.
    sender: Arc<dyn PubSubSenderInternal>,
}

impl Debug for ShardSubscriptionToken {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Exhaustively destructure so adding a field forces revisiting this
        // impl.
        let ShardSubscriptionToken {
            shard_id,
            sender: _sender,
        } = self;
        write!(f, "ShardSubscriptionToken({})", shard_id)
    }
}

impl Drop for ShardSubscriptionToken {
    fn drop(&mut self) {
        // Dropping the token is the (only) way a subscription is released.
        self.sender.unsubscribe(&self.shard_id);
    }
}
252
/// gRPC metadata key under which clients identify themselves to the server.
pub const PERSIST_PUBSUB_CALLER_KEY: &str = "persist-pubsub-caller-id";

/// Configuration for establishing a connection to the Persist PubSub service.
#[derive(Debug)]
pub struct PersistPubSubClientConfig {
    /// Address of the PubSub server; parsed as a tonic [Endpoint].
    pub url: String,
    /// Best-effort caller identifier, sent as request metadata for
    /// logging/debugging (falls back to "unknown" if not valid ASCII).
    pub caller_id: String,
    /// Persist configuration used to read the `persist_pubsub_*` dyncfgs.
    pub persist_cfg: PersistConfig,
}
266
/// A [PersistPubSubClient] implementation backed by a gRPC connection that is
/// maintained (and re-established) by a background task.
#[derive(Debug)]
pub struct GrpcPubSubClient;
274
impl GrpcPubSubClient {
    /// Maintains a connection to the PubSub server forever: connects with
    /// retry/backoff, streams broadcast requests up and server responses down
    /// into `receiver_input`, and reconnects when the stream ends.
    ///
    /// Returns only on a fatal error (unparseable URL) or once the receive
    /// side of the client has been dropped.
    async fn reconnect_to_server_forever(
        send_requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
        receiver_input: &tokio::sync::mpsc::Sender<ProtoPubSubMessage>,
        sender: Arc<SubscriptionTrackingSender>,
        metadata: MetadataMap,
        config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) {
        // Wait until dyncfgs have synced at least once so the flags below
        // reflect configured values rather than compiled-in defaults.
        config.persist_cfg.configs_synced_once().await;

        let mut is_first_connection_attempt = true;
        loop {
            let sender = Arc::clone(&sender);
            metrics.pubsub_client.grpc_connection.connected.set(0);

            // If the client is disabled, idle and re-check rather than
            // connecting.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                tokio::time::sleep(Duration::from_secs(5)).await;
                continue;
            }

            // Connect immediately the very first time; back off between
            // subsequent reconnection attempts.
            if is_first_connection_attempt {
                is_first_connection_attempt = false;
            } else {
                tokio::time::sleep(PUBSUB_RECONNECT_BACKOFF.get(&config.persist_cfg)).await;
            }

            info!("Connecting to Persist PubSub: {}", config.url);
            let client = mz_ore::retry::Retry::default()
                .clamp_backoff(PUBSUB_CONNECT_MAX_BACKOFF.get(&config.persist_cfg))
                .retry_async(|_| async {
                    metrics
                        .pubsub_client
                        .grpc_connection
                        .connect_call_attempt_count
                        .inc();
                    // An unparseable URL can never succeed; make the retry
                    // loop bail out entirely.
                    let endpoint = match Endpoint::from_str(&config.url) {
                        Ok(endpoint) => endpoint,
                        Err(err) => return RetryResult::FatalErr(err),
                    };
                    ProtoPersistPubSubClient::connect(
                        endpoint
                            .connect_timeout(
                                PUBSUB_CONNECT_ATTEMPT_TIMEOUT.get(&config.persist_cfg),
                            )
                            .timeout(PUBSUB_REQUEST_TIMEOUT.get(&config.persist_cfg)),
                    )
                    .await
                    .into()
                })
                .await;

            let mut client = match client {
                Ok(client) => client.max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
                Err(err) => {
                    error!("fatal error connecting to persist pubsub: {:?}", err);
                    return;
                }
            };

            metrics
                .pubsub_client
                .grpc_connection
                .connection_established_count
                .inc();
            metrics.pubsub_client.grpc_connection.connected.set(1);

            info!("Connected to Persist PubSub: {}", config.url);

            // Fan in requests broadcast by all senders in this process.
            let mut broadcast = BroadcastStream::new(send_requests.subscribe());
            let broadcast_errors = metrics
                .pubsub_client
                .grpc_connection
                .broadcast_recv_lagged_count
                .clone();

            // Dropped (below) once the response stream ends, to shut down the
            // request stream for this connection.
            let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();

            let broadcast_messages = async_stream::stream! {
                let mut cancel_rx = std::pin::pin!(cancel_rx);
                'reconnect: loop {
                    // Replay this client's live subscriptions so the (new)
                    // server connection learns about them.
                    for id in sender.subscriptions() {
                        debug!("re-subscribing to shard: {id}");
                        let msg = proto_pub_sub_message::Message::Subscribe(
                            ProtoSubscribe {
                                shard_id: id.into_proto(),
                            },
                        );
                        yield create_request(msg);
                    }

                    loop {
                        tokio::select! {
                            message = broadcast.next() => {
                                debug!("sending pubsub message: {:?}", message);
                                match message {
                                    Some(Ok(message)) => yield message,
                                    // We lagged the broadcast channel and may
                                    // have missed a Subscribe; start over and
                                    // re-send subscriptions.
                                    Some(Err(BroadcastStreamRecvError::Lagged(i))) => {
                                        broadcast_errors.inc_by(i);
                                        continue 'reconnect;
                                    }
                                    None => {
                                        debug!("exhausted pubsub broadcast stream; shutting down");
                                        return;
                                    }
                                }
                            }
                            // Fires when `cancel_tx` is dropped after the
                            // response stream ends.
                            _ = &mut cancel_rx => {
                                debug!("pubsub broadcast stream cancelled; shutting down");
                                return;
                            }
                        }
                    }
                }
            };
            let pubsub_request =
                Request::from_parts(metadata.clone(), Extensions::default(), broadcast_messages);

            let responses = match client.pub_sub(pubsub_request).await {
                Ok(response) => response.into_inner(),
                Err(err) => {
                    warn!("pub_sub rpc error: {:?}", err);
                    continue;
                }
            };

            let stream_completed = GrpcPubSubClient::consume_grpc_stream(
                responses,
                receiver_input,
                &config,
                metrics.as_ref(),
            )
            .await;

            // Signal the request stream (above) to shut down now that the
            // response stream has ended.
            drop(cancel_tx);

            match stream_completed {
                // The response stream ended; loop around and reconnect.
                Ok(_) => continue,
                // The receiver half was dropped; no one is listening, so stop
                // the connection loop for good.
                Err(err) => {
                    warn!("shutting down connection loop to Persist PubSub: {}", err);
                    return;
                }
            }
        }
    }

    /// Forwards messages from the server's response stream into
    /// `receiver_input`. Returns `Ok` when the stream ends, errors, or the
    /// client is disabled; returns `Err` only when the receiver channel has
    /// closed (i.e. the client-side receiver was dropped).
    async fn consume_grpc_stream(
        mut responses: Streaming<ProtoPubSubMessage>,
        receiver_input: &Sender<ProtoPubSubMessage>,
        config: &PersistPubSubClientConfig,
        metrics: &Metrics,
    ) -> Result<(), Error> {
        loop {
            // Re-checked each iteration so disabling the client takes effect
            // on a live connection.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                return Ok(());
            }

            debug!("awaiting next pubsub response");
            match responses.next().await {
                Some(Ok(message)) => {
                    debug!("received pubsub message: {:?}", message);
                    match receiver_input.send(message).await {
                        Ok(_) => {}
                        // The receiver was dropped; nothing left to deliver
                        // to, so tear the whole connection loop down.
                        Err(err) => {
                            return Err(anyhow!("closing pubsub grpc client connection: {}", err));
                        }
                    }
                }
                Some(Err(err)) => {
                    metrics.pubsub_client.grpc_connection.grpc_error_count.inc();
                    warn!("pubsub client error: {:?}", err);
                    return Ok(());
                }
                None => return Ok(()),
            }
        }
    }
}
473
impl PersistPubSubClient for GrpcPubSubClient {
    fn connect(config: PersistPubSubClientConfig, metrics: Arc<Metrics>) -> PubSubClientConnection {
        // A broadcast channel carries outgoing requests so the channel (and
        // the sender handles given to callers) outlives any individual gRPC
        // connection; each new connection subscribes to it afresh.
        let (send_requests, _) = tokio::sync::broadcast::channel(
            PUBSUB_CLIENT_SENDER_CHANNEL_SIZE.get(&config.persist_cfg),
        );
        let (receiver_input, receiver_output) = tokio::sync::mpsc::channel(
            PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&config.persist_cfg),
        );

        // Track subscriptions so they can be replayed after a reconnect.
        let sender = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            GrpcPubSubSender {
                metrics: Arc::clone(&metrics),
                requests: send_requests.clone(),
            },
        )));
        let pubsub_sender = Arc::clone(&sender);
        // Spawn the task that owns the connection and reconnects forever; the
        // handles returned below communicate with it via the channels above.
        mz_ore::task::spawn(
            || "persist::rpc::client::connection".to_string(),
            async move {
                let mut metadata = MetadataMap::new();
                metadata.insert(
                    AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY),
                    // Best effort: fall back to "unknown" if the caller id is
                    // not valid ASCII metadata.
                    AsciiMetadataValue::try_from(&config.caller_id)
                        .unwrap_or_else(|_| AsciiMetadataValue::from_static("unknown")),
                );

                GrpcPubSubClient::reconnect_to_server_forever(
                    send_requests,
                    &receiver_input,
                    pubsub_sender,
                    metadata,
                    config,
                    metrics,
                )
                .await;
            },
        );

        PubSubClientConnection {
            sender,
            receiver: Box::new(ReceiverStream::new(receiver_output)),
        }
    }
}
525
/// A gRPC-backed [PubSubSenderInternal] that broadcasts each message to the
/// client's connection task for transmission.
struct GrpcPubSubSender {
    /// Metrics recording send successes/failures and bytes sent.
    metrics: Arc<Metrics>,
    /// Broadcast channel consumed by the connection task.
    requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
}

impl Debug for GrpcPubSubSender {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Exhaustively destructure so adding a field forces revisiting this
        // impl.
        let GrpcPubSubSender {
            metrics: _metrics,
            requests: _requests,
        } = self;

        write!(f, "GrpcPubSubSender")
    }
}
542
543fn create_request(message: proto_pub_sub_message::Message) -> ProtoPubSubMessage {
544 let now = SystemTime::now()
545 .duration_since(SystemTime::UNIX_EPOCH)
546 .expect("failed to get millis since epoch");
547
548 ProtoPubSubMessage {
549 timestamp: Some(now.into_proto()),
550 message: Some(message),
551 }
552}
553
554impl GrpcPubSubSender {
555 fn send(&self, message: proto_pub_sub_message::Message, metrics: &PubSubClientCallMetrics) {
556 let size = message.encoded_len();
557
558 match self.requests.send(create_request(message)) {
559 Ok(_) => {
560 metrics.succeeded.inc();
561 metrics.bytes_sent.inc_by(u64::cast_from(size));
562 }
563 Err(err) => {
564 metrics.failed.inc();
565 debug!("error sending client message: {}", err);
566 }
567 }
568 }
569}
570
571impl PubSubSenderInternal for GrpcPubSubSender {
572 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
573 self.send(
574 proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
575 shard_id: shard_id.into_proto(),
576 seqno: diff.seqno.into_proto(),
577 diff: diff.data.clone(),
578 }),
579 &self.metrics.pubsub_client.sender.push,
580 )
581 }
582
583 fn subscribe(&self, shard_id: &ShardId) {
584 self.send(
585 proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
586 shard_id: shard_id.into_proto(),
587 }),
588 &self.metrics.pubsub_client.sender.subscribe,
589 )
590 }
591
592 fn unsubscribe(&self, shard_id: &ShardId) {
593 self.send(
594 proto_pub_sub_message::Message::Unsubscribe(ProtoUnsubscribe {
595 shard_id: shard_id.into_proto(),
596 }),
597 &self.metrics.pubsub_client.sender.unsubscribe,
598 )
599 }
600}
601
/// A [PubSubSender] that wraps a [PubSubSenderInternal] and tracks the set of
/// live shard subscriptions so they can be replayed after a reconnect.
#[derive(Debug)]
struct SubscriptionTrackingSender {
    /// The sender that actually transmits messages.
    delegate: Arc<dyn PubSubSenderInternal>,
    /// Weak refs to the subscription tokens handed out, keyed by shard.
    /// Entries whose token has been dropped are pruned lazily.
    subscribes: Arc<Mutex<BTreeMap<ShardId, Weak<ShardSubscriptionToken>>>>,
}
609
610impl SubscriptionTrackingSender {
611 fn new(sender: Arc<dyn PubSubSenderInternal>) -> Self {
612 Self {
613 delegate: sender,
614 subscribes: Default::default(),
615 }
616 }
617
618 fn subscriptions(&self) -> Vec<ShardId> {
619 let mut subscribes = self.subscribes.lock().expect("lock");
620 let mut out = Vec::with_capacity(subscribes.len());
621 subscribes.retain(|shard_id, token| {
622 if token.upgrade().is_none() {
623 false
624 } else {
625 debug!("reconnecting to: {}", shard_id);
626 out.push(*shard_id);
627 true
628 }
629 });
630 out
631 }
632}
633
impl PubSubSender for SubscriptionTrackingSender {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff)
    }

    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        let mut subscribes = self.subscribes.lock().expect("lock");
        // If a live token for this shard already exists, hand back a clone of
        // it instead of issuing a duplicate Subscribe.
        if let Some(token) = subscribes.get(shard_id) {
            match token.upgrade() {
                // The old token was dropped; prune the stale entry and fall
                // through to create a fresh subscription.
                None => assert!(subscribes.remove(shard_id).is_some()),
                Some(token) => {
                    return Arc::clone(&token);
                }
            }
        }

        // The token unsubscribes through the delegate when its last clone is
        // dropped.
        let pubsub_sender = Arc::clone(&self.delegate);
        let token = Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: pubsub_sender,
        });

        assert!(
            subscribes
                .insert(*shard_id, Arc::downgrade(&token))
                .is_none()
        );

        self.delegate.subscribe(shard_id);

        token
    }
}
667
/// A [PubSubSender] that delegates to another sender in the same process,
/// recording a client metric for each pushed diff.
#[derive(Debug)]
pub struct MetricsSameProcessPubSubSender {
    /// Whether `subscribe` calls are forwarded to the delegate; read once
    /// from [PUBSUB_SAME_PROCESS_DELEGATE_ENABLED] at construction.
    delegate_subscribe: bool,
    metrics: Arc<Metrics>,
    delegate: Arc<dyn PubSubSender>,
}
677
678impl MetricsSameProcessPubSubSender {
679 pub fn new(
682 cfg: &PersistConfig,
683 pubsub_sender: Arc<dyn PubSubSender>,
684 metrics: Arc<Metrics>,
685 ) -> Self {
686 Self {
687 delegate_subscribe: PUBSUB_SAME_PROCESS_DELEGATE_ENABLED.get(cfg),
688 delegate: pubsub_sender,
689 metrics,
690 }
691 }
692}
693
694impl PubSubSender for MetricsSameProcessPubSubSender {
695 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
696 self.delegate.push_diff(shard_id, diff);
697 self.metrics.pubsub_client.sender.push.succeeded.inc();
698 }
699
700 fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
701 if self.delegate_subscribe {
702 let delegate = Arc::clone(&self.delegate);
703 delegate.subscribe(shard_id)
704 } else {
705 Arc::new(ShardSubscriptionToken {
711 shard_id: *shard_id,
712 sender: Arc::new(NoopPubSubSender),
713 })
714 }
715 }
716}
717
/// A sender that discards every call; used when PubSub is disabled or
/// unwanted.
#[derive(Debug)]
pub(crate) struct NoopPubSubSender;

impl PubSubSenderInternal for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
    fn subscribe(&self, _shard_id: &ShardId) {}
    fn unsubscribe(&self, _shard_id: &ShardId) {}
}
726
727impl PubSubSender for NoopPubSubSender {
728 fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
729
730 fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
731 Arc::new(ShardSubscriptionToken {
732 shard_id: *shard_id,
733 sender: self,
734 })
735 }
736}
737
/// Spawns a task that applies diffs arriving on `pubsub_receiver` to the
/// matching shard state in `cache`. The task runs until the receiver stream
/// is exhausted.
pub(crate) fn subscribe_state_cache_to_pubsub(
    cache: Arc<StateCache>,
    mut pubsub_receiver: Box<dyn PubSubReceiver>,
) -> JoinHandle<()> {
    // Task-local cache of weak shard-state refs so repeat diffs for the same
    // shard can skip the shared StateCache lookup.
    let mut state_refs: HashMap<ShardId, Weak<dyn DynState>> = HashMap::new();
    let receiver_metrics = cache.metrics.pubsub_client.receiver.clone();

    mz_ore::task::spawn(
        || "persist::rpc::client::state_cache_diff_apply",
        async move {
            while let Some(msg) = pubsub_receiver.next().await {
                match msg.message {
                    Some(proto_pub_sub_message::Message::PushDiff(diff)) => {
                        receiver_metrics.push_received.inc();
                        let shard_id = diff.shard_id.into_rust().expect("valid shard id");
                        let diff = VersionedData {
                            seqno: diff.seqno.into_rust().expect("valid SeqNo"),
                            data: diff.diff,
                        };
                        debug!(
                            "applying pubsub diff {} {} {}",
                            shard_id,
                            diff.seqno,
                            diff.data.len()
                        );

                        let mut pushed_diff = false;
                        // Fast path: apply through the locally cached state
                        // ref if it is still alive.
                        if let Some(state_ref) = state_refs.get(&shard_id) {
                            if let Some(state) = state_ref.upgrade() {
                                state.push_diff(diff.clone());
                                pushed_diff = true;
                                receiver_metrics.state_pushed_diff_fast_path.inc();
                            }
                        }

                        // Slow path: look the state up in the shared cache
                        // and refresh (or drop) the local ref accordingly.
                        if !pushed_diff {
                            let state_ref = cache.get_state_weak(&shard_id);
                            match state_ref {
                                None => {
                                    state_refs.remove(&shard_id);
                                }
                                Some(state_ref) => {
                                    if let Some(state) = state_ref.upgrade() {
                                        state.push_diff(diff);
                                        pushed_diff = true;
                                        state_refs.insert(shard_id, state_ref);
                                    } else {
                                        state_refs.remove(&shard_id);
                                    }
                                }
                            }

                            if pushed_diff {
                                receiver_metrics.state_pushed_diff_slow_path_succeeded.inc();
                            } else {
                                receiver_metrics.state_pushed_diff_slow_path_failed.inc();
                            }
                        }

                        // Record approximate end-to-end diff latency from the
                        // sender's wall-clock timestamp; saturating_sub guards
                        // against clock skew across processes.
                        if let Some(send_timestamp) = msg.timestamp {
                            let send_timestamp =
                                send_timestamp.into_rust().expect("valid timestamp");
                            let now = SystemTime::now()
                                .duration_since(SystemTime::UNIX_EPOCH)
                                .expect("failed to get millis since epoch");
                            receiver_metrics
                                .approx_diff_latency_seconds
                                .observe((now.saturating_sub(send_timestamp)).as_secs_f64());
                        }
                    }
                    // Clients only expect PushDiff; anything else (including
                    // an empty message) is counted and ignored.
                    ref msg @ None | ref msg @ Some(_) => {
                        warn!("pubsub client received unexpected message: {:?}", msg);
                        receiver_metrics.unknown_message_received.inc();
                    }
                }
            }
        },
    )
}
824
/// Internal state of the PubSub server, shared across all connections.
#[derive(Debug)]
pub(crate) struct PubSubState {
    /// Assigns a unique id to each incoming connection.
    connection_id_counter: AtomicUsize,
    /// For each shard, the subscribed connections (keyed by connection id)
    /// and the channel used to notify each of them.
    shard_subscribers:
        Arc<RwLock<BTreeMap<ShardId, BTreeMap<usize, Sender<Result<ProtoPubSubMessage, Status>>>>>>,
    /// The set of all live connection ids.
    connections: Arc<RwLock<HashSet<usize>>>,
    /// Server-side metrics.
    metrics: Arc<PubSubServerMetrics>,
}
838
impl PubSubState {
    /// Registers a new connection with a fresh unique id and returns a
    /// [PubSubConnection] handle that routes its calls through this state.
    fn new_connection(
        self: Arc<Self>,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    ) -> PubSubConnection {
        let connection_id = self.connection_id_counter.fetch_add(1, Ordering::SeqCst);
        {
            debug!("inserting connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(connections.insert(connection_id));
        }

        self.metrics.active_connections.inc();
        PubSubConnection {
            connection_id,
            notifier,
            state: self,
        }
    }

    /// Removes a connection and all of its shard subscriptions. Panics if the
    /// connection id is unknown.
    fn remove_connection(&self, connection_id: usize) {
        let now = Instant::now();

        {
            debug!("removing connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(
                connections.remove(&connection_id),
                "unknown connection id: {}",
                connection_id
            );
        }

        {
            // Drop this connection from every shard's subscriber map, and
            // drop any shard entry left without subscribers.
            let mut subscribers = self.shard_subscribers.write().expect("lock poisoned");
            subscribers.retain(|_shard, connections_for_shard| {
                connections_for_shard.remove(&connection_id);
                !connections_for_shard.is_empty()
            });
        }

        self.metrics
            .connection_cleanup_seconds
            .inc_by(now.elapsed().as_secs_f64());
        self.metrics.active_connections.dec();
    }

    /// Forwards `data` to every connection subscribed to `shard_id`, except
    /// the pushing connection itself. Panics if `connection_id` is unknown.
    fn push_diff(&self, connection_id: usize, shard_id: &ShardId, data: &VersionedData) {
        let now = Instant::now();
        self.metrics.push_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        let subscribers = self.shard_subscribers.read().expect("lock poisoned");
        if let Some(subscribed_connections) = subscribers.get(shard_id) {
            let mut num_sent = 0;
            let mut data_size = 0;

            for (subscribed_conn_id, tx) in subscribed_connections {
                // Don't echo the diff back to the connection that pushed it.
                if *subscribed_conn_id == connection_id {
                    continue;
                }
                debug!(
                    "server forwarding req to conn {}: {} {} {}",
                    subscribed_conn_id,
                    &shard_id,
                    data.seqno,
                    data.data.len()
                );
                let req = create_request(proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                    seqno: data.seqno.into_proto(),
                    shard_id: shard_id.to_string(),
                    diff: Bytes::clone(&data.data),
                }));
                // Same encoded size each iteration; remembered for the bytes
                // metric below.
                data_size = req.encoded_len();
                // Best-effort delivery: a full or closed subscriber channel
                // drops the diff (full is separately counted).
                match tx.try_send(Ok(req)) {
                    Ok(_) => {
                        num_sent += 1;
                    }
                    Err(TrySendError::Full(_)) => {
                        self.metrics.broadcasted_diff_dropped_channel_full.inc();
                    }
                    Err(TrySendError::Closed(_)) => {}
                };
            }

            self.metrics.broadcasted_diff_count.inc_by(num_sent);
            self.metrics
                .broadcasted_diff_bytes
                .inc_by(num_sent * u64::cast_from(data_size));
        }

        self.metrics
            .push_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Subscribes `connection_id` to `shard_id`, replacing any notifier it
    /// previously registered for that shard. Panics if `connection_id` is
    /// unknown.
    fn subscribe(
        &self,
        connection_id: usize,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
        shard_id: &ShardId,
    ) {
        let now = Instant::now();
        self.metrics.subscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            subscribed_shards
                .entry(*shard_id)
                .or_default()
                .insert(connection_id, notifier);
        }

        self.metrics
            .subscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Removes `connection_id`'s subscription to `shard_id`, if any, dropping
    /// the shard's entry once it has no subscribers left. Panics if
    /// `connection_id` is unknown.
    fn unsubscribe(&self, connection_id: usize, shard_id: &ShardId) {
        let now = Instant::now();
        self.metrics.unsubscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            if let Entry::Occupied(mut entry) = subscribed_shards.entry(*shard_id) {
                let subscribed_connections = entry.get_mut();
                subscribed_connections.remove(&connection_id);

                if subscribed_connections.is_empty() {
                    entry.remove_entry();
                }
            }
        }

        self.metrics
            .unsubscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Creates an empty state with fresh metrics, for tests.
    #[cfg(test)]
    fn new_for_test() -> Self {
        Self {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(PubSubServerMetrics::new(&MetricsRegistry::new())),
        }
    }

    /// Returns the set of live connection ids, for tests.
    #[cfg(test)]
    fn active_connections(&self) -> HashSet<usize> {
        self.connections.read().expect("lock").clone()
    }

    /// Returns the shards `connection_id` is subscribed to, for tests.
    #[cfg(test)]
    fn subscriptions(&self, connection_id: usize) -> HashSet<ShardId> {
        let mut shards = HashSet::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            if subscribed_connections.contains_key(&connection_id) {
                shards.insert(*shard);
            }
        }

        shards
    }

    /// Returns the number of subscribers per shard, for tests.
    #[cfg(test)]
    fn shard_subscription_counts(&self) -> mz_ore::collections::HashMap<ShardId, usize> {
        let mut shards = mz_ore::collections::HashMap::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            shards.insert(*shard, subscribed_connections.len());
        }

        shards
    }
}
1046
/// A gRPC-based server implementation of the Persist PubSub service.
#[derive(Debug)]
pub struct PersistGrpcPubSubServer {
    /// Persist configuration, used to read the `persist_pubsub_*` dyncfgs.
    cfg: PersistConfig,
    /// Connection/subscription state shared by all connections.
    state: Arc<PubSubState>,
}
1053
impl PersistGrpcPubSubServer {
    /// Creates a new server with empty state, registering its metrics in
    /// `metrics_registry`.
    pub fn new(cfg: &PersistConfig, metrics_registry: &MetricsRegistry) -> Self {
        let metrics = PubSubServerMetrics::new(metrics_registry);
        let state = Arc::new(PubSubState {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(metrics),
        });

        PersistGrpcPubSubServer {
            cfg: cfg.clone(),
            state,
        }
    }

    /// Creates a client connection wired directly to this server's state,
    /// bypassing gRPC entirely.
    pub fn new_same_process_connection(&self) -> PubSubClientConnection {
        let (tx, rx) =
            tokio::sync::mpsc::channel(PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&self.cfg));
        let sender: Arc<dyn PubSubSender> = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            Arc::clone(&self.state).new_connection(tx),
        )));

        PubSubClientConnection {
            sender,
            receiver: Box::new(
                // The error half of the channel's item type is a gRPC
                // artifact; a local connection never sends errors.
                ReceiverStream::new(rx).map(|x| x.expect("cannot receive grpc errors locally")),
            ),
        }
    }

    /// Starts the server, binding to `listen_addr`; runs until the server
    /// exits.
    pub async fn serve(self, listen_addr: SocketAddr) -> Result<(), anyhow::Error> {
        tonic::transport::Server::builder()
            .add_service(
                ProtoPersistPubSubServer::new(self)
                    .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
            )
            .serve(listen_addr)
            .await?;
        Ok(())
    }

    /// Starts the server, accepting connections from the given TCP listener
    /// stream; runs until the server exits.
    pub async fn serve_with_stream(
        self,
        listener: tokio_stream::wrappers::TcpListenerStream,
    ) -> Result<(), anyhow::Error> {
        tonic::transport::Server::builder()
            .add_service(
                ProtoPersistPubSubServer::new(self)
                    .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
            )
            .serve_with_incoming(listener)
            .await?;
        Ok(())
    }
}
1118
1119#[async_trait]
1120impl proto_persist_pub_sub_server::ProtoPersistPubSub for PersistGrpcPubSubServer {
1121 type PubSubStream = Pin<Box<dyn Stream<Item = Result<ProtoPubSubMessage, Status>> + Send>>;
1122
1123 #[mz_ore::instrument(name = "persist::rpc::server", level = "info")]
1124 async fn pub_sub(
1125 &self,
1126 request: Request<Streaming<ProtoPubSubMessage>>,
1127 ) -> Result<Response<Self::PubSubStream>, Status> {
1128 let caller_id = request
1129 .metadata()
1130 .get(AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY))
1131 .map(|key| key.to_str().ok())
1132 .flatten()
1133 .map(|key| key.to_string())
1134 .unwrap_or_else(|| "unknown".to_string());
1135 info!("Received Persist PubSub connection from: {:?}", caller_id);
1136
1137 let mut in_stream = request.into_inner();
1138 let (tx, rx) =
1139 tokio::sync::mpsc::channel(PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE.get(&self.cfg));
1140
1141 let caller = caller_id.clone();
1142 let cfg = Arc::clone(&self.cfg.configs);
1143 let server_state = Arc::clone(&self.state);
1144 let connection_span = info_span!("connection", caller_id);
1148 mz_ore::task::spawn(
1149 || format!("persist_pubsub_connection({})", caller),
1150 async move {
1151 let connection = server_state.new_connection(tx);
1152 while let Some(result) = in_stream.next().await {
1153 let req = match result {
1154 Ok(req) => req,
1155 Err(err) => {
1156 warn!("pubsub connection err: {}", err);
1157 break;
1158 }
1159 };
1160
1161 match req.message {
1162 None => {
1163 warn!("received empty message from: {}", caller_id);
1164 }
1165 Some(proto_pub_sub_message::Message::PushDiff(req)) => {
1166 let shard_id = req.shard_id.parse().expect("valid shard id");
1167 let diff = VersionedData {
1168 seqno: req.seqno.into_rust().expect("valid seqno"),
1169 data: req.diff.clone(),
1170 };
1171 if PUBSUB_PUSH_DIFF_ENABLED.get(&cfg) {
1172 connection.push_diff(&shard_id, &diff);
1173 }
1174 }
1175 Some(proto_pub_sub_message::Message::Subscribe(diff)) => {
1176 let shard_id = diff.shard_id.parse().expect("valid shard id");
1177 connection.subscribe(&shard_id);
1178 }
1179 Some(proto_pub_sub_message::Message::Unsubscribe(diff)) => {
1180 let shard_id = diff.shard_id.parse().expect("valid shard id");
1181 connection.unsubscribe(&shard_id);
1182 }
1183 }
1184 }
1185
1186 info!("Persist PubSub connection ended: {:?}", caller_id);
1187 }
1188 .instrument(connection_span),
1189 );
1190
1191 let out_stream: Self::PubSubStream = Box::pin(ReceiverStream::new(rx));
1192 Ok(Response::new(out_stream))
1193 }
1194}
1195
/// A server-side handle to a registered connection in [PubSubState].
/// Unregisters the connection when dropped.
#[derive(Debug)]
pub(crate) struct PubSubConnection {
    /// Unique id assigned by [PubSubState::new_connection].
    connection_id: usize,
    /// Channel on which this connection receives forwarded messages.
    notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    /// The shared server state this connection is registered with.
    state: Arc<PubSubState>,
}
1205
// Each call delegates to the shared state, identified by this connection's
// id.
impl PubSubSenderInternal for PubSubConnection {
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.state.push_diff(self.connection_id, shard_id, diff)
    }

    fn subscribe(&self, shard_id: &ShardId) {
        self.state
            .subscribe(self.connection_id, self.notifier.clone(), shard_id)
    }

    fn unsubscribe(&self, shard_id: &ShardId) {
        self.state.unsubscribe(self.connection_id, shard_id)
    }
}
1220
impl Drop for PubSubConnection {
    fn drop(&mut self) {
        // Unregister the connection, dropping all of its subscriptions.
        self.state.remove_connection(self.connection_id)
    }
}
1226
#[cfg(test)]
mod pubsub_state {
    //! Unit tests for [crate::rpc::PubSubState]: connection registration,
    //! subscription bookkeeping, and diff fan-out, exercised directly with no
    //! networking involved.

    use std::str::FromStr;
    use std::sync::Arc;
    use std::sync::LazyLock;

    use bytes::Bytes;
    use mz_ore::collections::HashSet;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use tokio::sync::mpsc::Receiver;
    use tokio::sync::mpsc::error::TryRecvError;
    use tonic::Status;

    use crate::ShardId;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::rpc::{PubSubSenderInternal, PubSubState};

    // Two distinct shard ids used throughout the tests below.
    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());

    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };

    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    // Operations on a connection id that was never registered are bugs in the
    // caller and must panic rather than be silently ignored.

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_push_diff() {
        let state = Arc::new(PubSubState::new_for_test());
        state.push_diff(100, &SHARD_ID_0, &VERSIONED_DATA_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_subscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        let (tx, _) = tokio::sync::mpsc::channel(100);
        state.subscribe(100, tx, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_unsubscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        state.unsubscribe(100, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_remove() {
        let state = Arc::new(PubSubState::new_for_test());
        state.remove_connection(100)
    }

    /// A single connection: its own pushes are not echoed back to it,
    /// subscribe/unsubscribe are idempotent, and dropping the connection
    /// clears all of its state.
    #[mz_ore::test]
    fn test_single_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx, mut rx) = tokio::sync::mpsc::channel(100);
        let connection = Arc::clone(&state).new_connection(tx);

        assert_eq!(
            state.active_connections(),
            HashSet::from([connection.connection_id])
        );

        // No messages should arrive before anything is pushed.
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        connection.push_diff(
            &SHARD_ID_0,
            &VersionedData {
                seqno: SeqNo::minimum(),
                data: Bytes::new(),
            },
        );

        // The pushing connection does not receive its own diff.
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0])
        );

        connection.unsubscribe(&SHARD_ID_0);
        assert!(state.subscriptions(connection.connection_id).is_empty());

        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_1);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // Re-subscribing to an already-subscribed shard is a no-op.
        connection.subscribe(&SHARD_ID_0);
        connection.subscribe(&SHARD_ID_0);
        assert_eq!(
            state.subscriptions(connection.connection_id),
            HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
        );

        // Dropping the connection removes it and all of its subscriptions.
        let connection_id = connection.connection_id;
        drop(connection);
        assert!(state.subscriptions(connection_id).is_empty());
        assert!(state.active_connections().is_empty());
    }

    /// Multiple connections: diffs fan out to all other subscribers of the
    /// shard (never back to the pusher), and dropping one connection does not
    /// disturb the others.
    #[mz_ore::test]
    fn test_many_connection() {
        let state = Arc::new(PubSubState::new_for_test());

        let (tx1, mut rx1) = tokio::sync::mpsc::channel(100);
        let conn1 = Arc::clone(&state).new_connection(tx1);

        let (tx2, mut rx2) = tokio::sync::mpsc::channel(100);
        let conn2 = Arc::clone(&state).new_connection(tx2);

        let (tx3, mut rx3) = tokio::sync::mpsc::channel(100);
        let conn3 = Arc::clone(&state).new_connection(tx3);

        conn1.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_0);
        conn2.subscribe(&SHARD_ID_1);

        assert_eq!(
            state.active_connections(),
            HashSet::from([
                conn1.connection_id,
                conn2.connection_id,
                conn3.connection_id
            ])
        );

        // A push from an unsubscribed connection reaches all subscribers.
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx1, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // A push from a subscriber reaches the other subscribers only.
        conn1.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // Only conn2 is subscribed to SHARD_ID_1.
        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert_push(&mut rx2, &SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // After unsubscribing, nobody receives SHARD_ID_1 diffs.
        conn2.unsubscribe(&SHARD_ID_1);
        conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
        assert!(matches!(rx2.try_recv(), Err(TryRecvError::Empty)));
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        // Dropping conn1 disconnects its receiver; other subscribers still work.
        let conn1_id = conn1.connection_id;
        drop(conn1);
        conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx1.try_recv(), Err(TryRecvError::Disconnected)));
        assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
        assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));

        assert!(state.subscriptions(conn1_id).is_empty());
        assert_eq!(
            state.subscriptions(conn2.connection_id),
            HashSet::from([*SHARD_ID_0])
        );
        assert_eq!(state.subscriptions(conn3.connection_id), HashSet::new());
        assert_eq!(
            state.active_connections(),
            HashSet::from([conn2.connection_id, conn3.connection_id])
        );
    }

    /// Asserts that the next message in `rx` is a PushDiff for `shard` that
    /// carries exactly `data`.
    fn assert_push(
        rx: &mut Receiver<Result<ProtoPubSubMessage, Status>>,
        shard: &ShardId,
        data: &VersionedData,
    ) {
        let message = rx
            .try_recv()
            .expect("message in channel")
            .expect("pubsub")
            .message
            .expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }
}
1442
/// End-to-end tests of the gRPC PubSub client and server over real TCP
/// listeners, including behavior across client shutdown and server restarts.
#[cfg(test)]
mod grpc {
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::str::FromStr;
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use futures_util::FutureExt;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::assert_none;
    use mz_ore::collections::HashMap;
    use mz_ore::metrics::MetricsRegistry;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use std::sync::LazyLock;
    use tokio::net::TcpListener;
    use tokio_stream::StreamExt;
    use tokio_stream::wrappers::TcpListenerStream;

    use crate::ShardId;
    use crate::cfg::PersistConfig;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::metrics::Metrics;
    use crate::rpc::{
        GrpcPubSubClient, PUBSUB_CLIENT_ENABLED, PUBSUB_RECONNECT_BACKOFF, PersistGrpcPubSubServer,
        PersistPubSubClient, PersistPubSubClientConfig, PubSubState,
    };

    // Two distinct shard ids and payloads shared by the tests below.
    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };
    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    // Upper bounds for polling loops; tests pass as soon as the condition
    // holds, so generous values only matter on failure.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
    const SUBSCRIPTIONS_TIMEOUT: Duration = Duration::from_secs(3);
    const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);

    // A client connects and subscribes; the server observes the connection and
    // subscription, and notices both go away when the client runtime is torn
    // down.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] fn grpc_server() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        // Separate runtimes so the client can be shut down independently of
        // the server.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");

        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        {
            // Spawn the client on its own runtime; it subscribes once and then
            // parks forever (until the runtime is shut down below).
            let _guard = client_runtime.enter();
            mz_ore::task::spawn(|| "client".to_string(), async move {
                let client = GrpcPubSubClient::connect(
                    PersistPubSubClientConfig {
                        url: format!("http://{}", addr),
                        caller_id: "client".to_string(),
                        persist_cfg: test_persist_config(),
                    },
                    metrics,
                );
                let _token = client.sender.subscribe(&SHARD_ID_0);
                tokio::time::sleep(Duration::MAX).await;
            });
        }

        // The server should see one connection with one subscription.
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await
        });

        // Kill the client; the server should observe the disconnect and clean
        // up the subscription state.
        client_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().is_empty()
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::new()
            })
            .await
        });
    }

    // Subscriptions taken while the server is unavailable are applied once it
    // comes up, and active subscriptions are re-applied after the server
    // restarts.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] fn grpc_client_sender_reconnects() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // Create the client before any server exists at `addr`.
        let client = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // Subscribe to SHARD_ID_0 and keep the token; subscribe to SHARD_ID_1
        // but drop its token immediately, so only SHARD_ID_0 stays active.
        let _token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        drop(_token_2);

        // Now start the server; the client should connect and apply only the
        // still-held subscription.
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await;
        });

        // Tear the server down, then take a new SHARD_ID_1 subscription while
        // disconnected.
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);

        // Restart a server on the same address; the client should reconnect
        // and re-apply both active subscriptions.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts()
                    == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
            })
            .await;
        });
    }

    // Subscription tokens are reference-counted: repeat subscriptions to the
    // same shard share one token, and the server-side subscription lives until
    // every token is dropped.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] async fn grpc_client_sender_subscription_tokens() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));

        let (addr, tcp_listener_stream) = new_tcp_listener().await;
        let server_state = spawn_server(tcp_listener_stream).await;

        let client = GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client".to_string(),
                persist_cfg: test_persist_config(),
            },
            metrics,
        );

        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;

        // A held token creates exactly one server-side subscription.
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // Dropping the token unsubscribes.
        drop(token);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // Resubscribing after a drop works.
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // Repeat subscriptions to the same shard return clones of the same
        // token and do not create extra server-side subscriptions.
        let token2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let token3 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        assert_eq!(Arc::strong_count(&token), 3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // Only after all tokens are dropped does the subscription go away.
        drop(token);
        drop(token2);
        drop(token3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // Distinct shards get distinct subscriptions.
        let _token0 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let _token1 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts()
                == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
        })
        .await;
    }

    // Diffs pushed by one client are delivered to other subscribed clients
    // (never echoed back to the pusher), including after a server restart.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] fn grpc_client_receiver() {
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // Both clients are created before the server exists.
        let mut client_1 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_1".to_string(),
                    persist_cfg: test_persist_config(),
                },
                Arc::clone(&metrics),
            )
        });
        let mut client_2 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_2".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // Nothing has been pushed yet, so neither receiver has a message.
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 2
        }));

        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // Subscribe both clients to the same shard.
        let _token_client_1 = Arc::clone(&client_1.sender).subscribe(&SHARD_ID_0);
        let _token_client_2 = Arc::clone(&client_2.sender).subscribe(&SHARD_ID_0);
        server_runtime.block_on(poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
        }));

        // A push from client_1 arrives at client_2 only.
        client_1.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_1);
        assert_none!(client_1.receiver.next().now_or_never());
        client_runtime.block_on(async {
            assert_push(
                client_2.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_1,
            )
        });

        // Restart the server on the same address.
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // Both clients reconnect and their subscriptions are re-applied.
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 2
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
            })
            .await;
        });

        // Fan-out still works after the restart, in the other direction.
        client_2.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        client_runtime.block_on(async {
            assert_push(
                client_1.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_0,
            )
        });
        assert_none!(client_2.receiver.next().now_or_never());
    }

    /// Binds a TCP listener on an OS-assigned localhost port and returns the
    /// resolved address along with the listener as a stream.
    async fn new_tcp_listener() -> (SocketAddr, TcpListenerStream) {
        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let tcp_listener = TcpListener::bind(addr).await.expect("tcp listener");

        (
            tcp_listener.local_addr().expect("bound to local address"),
            TcpListenerStream::new(tcp_listener),
        )
    }

    /// Spawns a PubSub server task serving `tcp_listener_stream` and returns
    /// a handle to its shared state for assertions.
    #[allow(clippy::unused_async)]
    async fn spawn_server(tcp_listener_stream: TcpListenerStream) -> Arc<PubSubState> {
        let server = PersistGrpcPubSubServer::new(&test_persist_config(), &MetricsRegistry::new());
        let server_state = Arc::clone(&server.state);

        // The server task is intentionally detached; it runs until its
        // runtime is shut down.
        let _server_task = mz_ore::task::spawn(|| "server".to_string(), async move {
            server.serve_with_stream(tcp_listener_stream).await
        });
        server_state
    }

    /// Polls `f` every millisecond until it returns true, panicking if it
    /// does not within `timeout`.
    async fn poll_until_true<F>(timeout: Duration, f: F)
    where
        F: Fn() -> bool,
    {
        let now = Instant::now();
        loop {
            if f() {
                return;
            }

            if now.elapsed() > timeout {
                panic!("timed out");
            }

            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }

    /// Asserts that `message` is a PushDiff for `shard` carrying exactly
    /// `data`.
    fn assert_push(message: ProtoPubSubMessage, shard: &ShardId, data: &VersionedData) {
        let message = message.message.expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }

    /// A test PersistConfig with the PubSub client force-enabled and the
    /// reconnect backoff zeroed so reconnection tests run quickly.
    fn test_persist_config() -> PersistConfig {
        let cfg = PersistConfig::new_for_tests();

        let mut updates = ConfigUpdates::default();
        updates.add(&PUBSUB_CLIENT_ENABLED, true);
        updates.add(&PUBSUB_RECONNECT_BACKOFF, Duration::ZERO);
        cfg.apply_from(&updates);

        cfg
    }
}