1use std::collections::BTreeMap;
13use std::collections::btree_map::Entry;
14use std::fmt::{Debug, Formatter};
15use std::net::SocketAddr;
16use std::pin::Pin;
17use std::str::FromStr;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::{Arc, Mutex, RwLock, Weak};
20use std::time::{Duration, Instant, SystemTime};
21
22use anyhow::{Error, anyhow};
23use async_trait::async_trait;
24use bytes::Bytes;
25use futures::Stream;
26use futures_util::StreamExt;
27use mz_dyncfg::Config;
28use mz_ore::cast::CastFrom;
29use mz_ore::collections::{HashMap, HashSet};
30use mz_ore::metrics::MetricsRegistry;
31use mz_ore::retry::RetryResult;
32use mz_ore::task::JoinHandle;
33use mz_persist::location::VersionedData;
34use mz_proto::{ProtoType, RustType};
35use prost::Message;
36use tokio::sync::mpsc::Sender;
37use tokio::sync::mpsc::error::TrySendError;
38use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
39use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
40use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap};
41use tonic::transport::Endpoint;
42use tonic::{Extensions, Request, Response, Status, Streaming};
43use tracing::{Instrument, debug, error, info, info_span, warn};
44
45use crate::ShardId;
46use crate::cache::{DynState, StateCache};
47use crate::cfg::PersistConfig;
48use crate::internal::metrics::{PubSubClientCallMetrics, PubSubServerMetrics};
49use crate::internal::service::proto_persist_pub_sub_client::ProtoPersistPubSubClient;
50use crate::internal::service::proto_persist_pub_sub_server::ProtoPersistPubSubServer;
51use crate::internal::service::{
52 ProtoPubSubMessage, ProtoPushDiff, ProtoSubscribe, ProtoUnsubscribe,
53 proto_persist_pub_sub_server, proto_pub_sub_message,
54};
55use crate::metrics::Metrics;
56
/// Whether to connect to the Persist PubSub service at all.
pub(crate) const PUBSUB_CLIENT_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_client_enabled",
    true,
    "Whether to connect to the Persist PubSub service.",
);

/// Whether to push state diffs to the PubSub service (can be disabled
/// independently of the connection itself).
pub(crate) const PUBSUB_PUSH_DIFF_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_push_diff_enabled",
    true,
    "Whether to push state diffs to Persist PubSub.",
);

/// Whether an in-process PubSub sender delegates subscribes to its inner
/// sender (see [MetricsSameProcessPubSubSender]).
pub(crate) const PUBSUB_SAME_PROCESS_DELEGATE_ENABLED: Config<bool> = Config::new(
    "persist_pubsub_same_process_delegate_enabled",
    true,
    "Whether to push state diffs to Persist PubSub on the same process.",
);

/// Timeout applied to each individual connection attempt.
pub(crate) const PUBSUB_CONNECT_ATTEMPT_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_connect_attempt_timeout",
    Duration::from_secs(5),
    "Timeout per connection attempt to Persist PubSub service.",
);

/// Timeout applied to each individual gRPC request.
pub(crate) const PUBSUB_REQUEST_TIMEOUT: Config<Duration> = Config::new(
    "persist_pubsub_request_timeout",
    Duration::from_secs(5),
    "Timeout per request attempt to Persist PubSub service.",
);

/// Cap on the exponential backoff used while (re)establishing a connection.
pub(crate) const PUBSUB_CONNECT_MAX_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_connect_max_backoff",
    Duration::from_secs(60),
    "Maximum backoff when retrying connection establishment to Persist PubSub service.",
);

/// Capacity of the client-side broadcast channel buffering outgoing messages.
pub(crate) const PUBSUB_CLIENT_SENDER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_sender_channel_size",
    25,
    "Size of channel used to buffer send messages to PubSub service.",
);

/// Capacity of the client-side mpsc channel buffering incoming messages.
pub(crate) const PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_client_receiver_channel_size",
    25,
    "Size of channel used to buffer received messages from PubSub service.",
);

/// Capacity of the per-connection channel the server uses to fan out messages.
pub(crate) const PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_server_connection_channel_size",
    25,
    "Size of channel used per connection to buffer broadcasted messages from PubSub server.",
);

/// Capacity of the channel the state cache uses to broadcast shard state refs.
pub(crate) const PUBSUB_STATE_CACHE_SHARD_REF_CHANNEL_SIZE: Config<usize> = Config::new(
    "persist_pubsub_state_cache_shard_ref_channel_size",
    25,
    "Size of channel used by the state cache to broadcast shard state references.",
);

/// Backoff before reconnecting after an established connection drops.
pub(crate) const PUBSUB_RECONNECT_BACKOFF: Config<Duration> = Config::new(
    "persist_pubsub_reconnect_backoff",
    Duration::from_secs(5),
    "Backoff after an established connection to Persist PubSub service fails.",
);
137
/// A factory for connections to the Persist PubSub service.
pub trait PersistPubSubClient {
    /// Establishes a connection, returning a handle pair with which to push
    /// diffs to and receive diffs from the service.
    fn connect(
        pubsub_config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) -> PubSubClientConnection;
}
149
/// The send and receive halves of a connection to the PubSub service.
#[derive(Debug)]
pub struct PubSubClientConnection {
    /// Pushes diffs and subscription changes to the service.
    pub sender: Arc<dyn PubSubSender>,
    /// Yields messages broadcast by the service.
    pub receiver: Box<dyn PubSubReceiver>,
}
158
159impl PubSubClientConnection {
160 pub fn new(sender: Arc<dyn PubSubSender>, receiver: Box<dyn PubSubReceiver>) -> Self {
162 Self { sender, receiver }
163 }
164
165 pub fn noop() -> Self {
167 Self {
168 sender: Arc::new(NoopPubSubSender),
169 receiver: Box::new(futures::stream::empty()),
170 }
171 }
172}
173
/// The public send-side interface to the PubSub service.
pub trait PubSubSender: std::fmt::Debug + Send + Sync {
    /// Pushes a diff for the given shard to the service (best-effort).
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Subscribes to diffs for the given shard. The subscription is live for
    /// as long as the returned token is held; dropping the token
    /// unsubscribes (see [ShardSubscriptionToken]'s `Drop` impl).
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken>;
}
188
/// Internal send-side interface with explicit subscribe/unsubscribe calls.
/// Unlike [PubSubSender], subscription lifetime is managed by the caller
/// rather than by a token.
trait PubSubSenderInternal: Debug + Send + Sync {
    /// Pushes a diff for the given shard.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);

    /// Registers interest in diffs for the given shard.
    fn subscribe(&self, shard_id: &ShardId);

    /// Removes interest in diffs for the given shard.
    fn unsubscribe(&self, shard_id: &ShardId);
}
207
/// The receive-side of a PubSub connection: a stream of messages from the
/// service. Blanket-implemented for any suitable stream type.
pub trait PubSubReceiver:
    Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}

impl<T> PubSubReceiver for T where
    T: Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
{
}
221
/// A token representing a live subscription to a shard's diffs. The
/// subscription is released when the token is dropped.
pub struct ShardSubscriptionToken {
    // The shard this token subscribes to.
    pub(crate) shard_id: ShardId,
    // Sender used to issue the unsubscribe on drop.
    sender: Arc<dyn PubSubSenderInternal>,
}
230
231impl Debug for ShardSubscriptionToken {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 let ShardSubscriptionToken {
234 shard_id,
235 sender: _sender,
236 } = self;
237 write!(f, "ShardSubscriptionToken({})", shard_id)
238 }
239}
240
impl Drop for ShardSubscriptionToken {
    /// Unsubscribes from the shard when the last token clone is dropped.
    fn drop(&mut self) {
        self.sender.unsubscribe(&self.shard_id);
    }
}
246
/// Metadata key used to propagate the caller's identity to the PubSub server
/// (used only for logging/observability).
pub const PERSIST_PUBSUB_CALLER_KEY: &str = "persist-pubsub-caller-id";
249
/// Configuration for connecting to the PubSub service.
#[derive(Debug)]
pub struct PersistPubSubClientConfig {
    /// The gRPC endpoint URL of the PubSub server.
    pub url: String,
    /// An identifier for the caller, sent as request metadata for logging.
    pub caller_id: String,
    /// Persist configuration, consulted for dyncfg-driven behavior.
    pub persist_cfg: PersistConfig,
}
260
/// A [PersistPubSubClient] implementation backed by gRPC, which maintains
/// its connection in a background task and transparently reconnects.
#[derive(Debug)]
pub struct GrpcPubSubClient;
268
impl GrpcPubSubClient {
    /// Connects to the PubSub server and keeps the connection alive forever,
    /// reconnecting (with backoff) whenever it drops. Returns only on a
    /// fatal error (bad URL) or when the receiver side has shut down.
    ///
    /// On each (re)connection, all shards still tracked by `sender` are
    /// re-subscribed before outgoing messages are forwarded.
    async fn reconnect_to_server_forever(
        send_requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
        receiver_input: &tokio::sync::mpsc::Sender<ProtoPubSubMessage>,
        sender: Arc<SubscriptionTrackingSender>,
        metadata: MetadataMap,
        config: PersistPubSubClientConfig,
        metrics: Arc<Metrics>,
    ) {
        // Wait for dyncfgs to have synced at least once so the flags read
        // below reflect configured values rather than code defaults.
        config.persist_cfg.configs_synced_once().await;

        let mut is_first_connection_attempt = true;
        loop {
            let sender = Arc::clone(&sender);
            metrics.pubsub_client.grpc_connection.connected.set(0);

            // While the client is disabled, just poll the flag periodically.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                tokio::time::sleep(Duration::from_secs(5)).await;
                continue;
            }

            // Connect immediately on the first attempt; on reconnects, back
            // off so a flapping server isn't hammered.
            if is_first_connection_attempt {
                is_first_connection_attempt = false;
            } else {
                tokio::time::sleep(PUBSUB_RECONNECT_BACKOFF.get(&config.persist_cfg)).await;
            }

            info!("Connecting to Persist PubSub: {}", config.url);
            let client = mz_ore::retry::Retry::default()
                .clamp_backoff(PUBSUB_CONNECT_MAX_BACKOFF.get(&config.persist_cfg))
                .retry_async(|_| async {
                    metrics
                        .pubsub_client
                        .grpc_connection
                        .connect_call_attempt_count
                        .inc();
                    let endpoint = match Endpoint::from_str(&config.url) {
                        Ok(endpoint) => endpoint,
                        // A malformed URL can never succeed; stop retrying.
                        Err(err) => return RetryResult::FatalErr(err),
                    };
                    ProtoPersistPubSubClient::connect(
                        endpoint
                            .connect_timeout(
                                PUBSUB_CONNECT_ATTEMPT_TIMEOUT.get(&config.persist_cfg),
                            )
                            .timeout(PUBSUB_REQUEST_TIMEOUT.get(&config.persist_cfg)),
                    )
                    .await
                    .into()
                })
                .await;

            let mut client = match client {
                Ok(client) => client,
                Err(err) => {
                    error!("fatal error connecting to persist pubsub: {:?}", err);
                    return;
                }
            };

            metrics
                .pubsub_client
                .grpc_connection
                .connection_established_count
                .inc();
            metrics.pubsub_client.grpc_connection.connected.set(1);

            info!("Connected to Persist PubSub: {}", config.url);

            let mut broadcast = BroadcastStream::new(send_requests.subscribe());
            let broadcast_errors = metrics
                .pubsub_client
                .grpc_connection
                .broadcast_recv_lagged_count
                .clone();

            // Outgoing request stream: first replay the current set of
            // subscriptions, then forward broadcasted messages. If the
            // broadcast receiver lags (messages were dropped), start over
            // from the re-subscribe step so the server's view stays correct.
            let broadcast_messages = async_stream::stream! {
                'reconnect: loop {
                    for id in sender.subscriptions() {
                        debug!("re-subscribing to shard: {id}");
                        let msg = proto_pub_sub_message::Message::Subscribe(
                            ProtoSubscribe {
                                shard_id: id.into_proto(),
                            },
                        );
                        yield create_request(msg);
                    }

                    while let Some(message) = broadcast.next().await {
                        debug!("sending pubsub message: {:?}", message);
                        match message {
                            Ok(message) => yield message,
                            Err(BroadcastStreamRecvError::Lagged(i)) => {
                                broadcast_errors.inc_by(i);
                                continue 'reconnect;
                            }
                        }
                    }
                    debug!("exhausted pubsub broadcast stream; shutting down");
                    break;
                }
            };
            let pubsub_request =
                Request::from_parts(metadata.clone(), Extensions::default(), broadcast_messages);

            let responses = match client.pub_sub(pubsub_request).await {
                Ok(response) => response.into_inner(),
                Err(err) => {
                    warn!("pub_sub rpc error: {:?}", err);
                    continue;
                }
            };

            let stream_completed = GrpcPubSubClient::consume_grpc_stream(
                responses,
                receiver_input,
                &config,
                metrics.as_ref(),
            )
            .await;

            match stream_completed {
                // Ok means the stream ended benignly (e.g. server restart);
                // loop around and reconnect.
                Ok(_) => continue,
                // Err means the local receiver is gone; there is no one left
                // to deliver messages to, so shut the loop down.
                Err(err) => {
                    warn!("shutting down connection loop to Persist PubSub: {}", err);
                    return;
                }
            }
        }
    }

    /// Forwards messages from the gRPC response stream into
    /// `receiver_input` until the stream ends or the client is disabled.
    ///
    /// Returns `Ok(())` when the connection should be retried, and `Err`
    /// only when the local receiver channel has closed (no retry possible).
    async fn consume_grpc_stream(
        mut responses: Streaming<ProtoPubSubMessage>,
        receiver_input: &Sender<ProtoPubSubMessage>,
        config: &PersistPubSubClientConfig,
        metrics: &Metrics,
    ) -> Result<(), Error> {
        loop {
            // Re-check the flag each iteration so disabling the client
            // takes effect on a live connection.
            if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
                return Ok(());
            }

            debug!("awaiting next pubsub response");
            match responses.next().await {
                Some(Ok(message)) => {
                    debug!("received pubsub message: {:?}", message);
                    match receiver_input.send(message).await {
                        Ok(_) => {}
                        // The receiver half was dropped; the connection is
                        // useless, so surface a terminal error.
                        Err(err) => {
                            return Err(anyhow!("closing pubsub grpc client connection: {}", err));
                        }
                    }
                }
                Some(Err(err)) => {
                    metrics.pubsub_client.grpc_connection.grpc_error_count.inc();
                    warn!("pubsub client error: {:?}", err);
                    return Ok(());
                }
                None => return Ok(()),
            }
        }
    }
}
447
impl PersistPubSubClient for GrpcPubSubClient {
    fn connect(config: PersistPubSubClientConfig, metrics: Arc<Metrics>) -> PubSubClientConnection {
        // Outgoing messages go over a broadcast channel so the background
        // connection task can (re)subscribe to it on each reconnect.
        let (send_requests, _) = tokio::sync::broadcast::channel(
            PUBSUB_CLIENT_SENDER_CHANNEL_SIZE.get(&config.persist_cfg),
        );
        // Incoming messages are funneled into an mpsc channel whose receiver
        // end is handed back to the caller.
        let (receiver_input, receiver_output) = tokio::sync::mpsc::channel(
            PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&config.persist_cfg),
        );

        // Wrap the raw gRPC sender so live subscriptions are tracked and can
        // be replayed after a reconnect.
        let sender = Arc::new(SubscriptionTrackingSender::new(Arc::new(
            GrpcPubSubSender {
                metrics: Arc::clone(&metrics),
                requests: send_requests.clone(),
            },
        )));
        let pubsub_sender = Arc::clone(&sender);
        mz_ore::task::spawn(
            || "persist::rpc::client::connection".to_string(),
            async move {
                // Attach the caller id as request metadata for server logs;
                // fall back to "unknown" if it isn't valid ASCII metadata.
                let mut metadata = MetadataMap::new();
                metadata.insert(
                    AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY),
                    AsciiMetadataValue::try_from(&config.caller_id)
                        .unwrap_or_else(|_| AsciiMetadataValue::from_static("unknown")),
                );

                GrpcPubSubClient::reconnect_to_server_forever(
                    send_requests,
                    &receiver_input,
                    pubsub_sender,
                    metadata,
                    config,
                    metrics,
                )
                .await;
            },
        );

        PubSubClientConnection {
            sender,
            receiver: Box::new(ReceiverStream::new(receiver_output)),
        }
    }
}
499
/// An internal sender that encodes messages and writes them to the broadcast
/// channel consumed by the gRPC connection task.
struct GrpcPubSubSender {
    // Metrics for send successes/failures and bytes sent.
    metrics: Arc<Metrics>,
    // Broadcast channel feeding the active gRPC request stream.
    requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
}
505
506impl Debug for GrpcPubSubSender {
507 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
508 let GrpcPubSubSender {
509 metrics: _metrics,
510 requests: _requests,
511 } = self;
512
513 write!(f, "GrpcPubSubSender")
514 }
515}
516
517fn create_request(message: proto_pub_sub_message::Message) -> ProtoPubSubMessage {
518 let now = SystemTime::now()
519 .duration_since(SystemTime::UNIX_EPOCH)
520 .expect("failed to get millis since epoch");
521
522 ProtoPubSubMessage {
523 timestamp: Some(now.into_proto()),
524 message: Some(message),
525 }
526}
527
impl GrpcPubSubSender {
    /// Sends a message over the broadcast channel, recording success/failure
    /// and bytes sent in `metrics`. Best-effort: a send fails only when no
    /// connection task is currently subscribed, and is merely logged.
    fn send(&self, message: proto_pub_sub_message::Message, metrics: &PubSubClientCallMetrics) {
        let size = message.encoded_len();

        match self.requests.send(create_request(message)) {
            Ok(_) => {
                metrics.succeeded.inc();
                metrics.bytes_sent.inc_by(u64::cast_from(size));
            }
            Err(err) => {
                metrics.failed.inc();
                debug!("error sending client message: {}", err);
            }
        }
    }
}
544
impl PubSubSenderInternal for GrpcPubSubSender {
    /// Encodes and sends a PushDiff message, tracked by the `push` metrics.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.send(
            proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                shard_id: shard_id.into_proto(),
                seqno: diff.seqno.into_proto(),
                diff: diff.data.clone(),
            }),
            &self.metrics.pubsub_client.sender.push,
        )
    }

    /// Encodes and sends a Subscribe message, tracked by the `subscribe` metrics.
    fn subscribe(&self, shard_id: &ShardId) {
        self.send(
            proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
                shard_id: shard_id.into_proto(),
            }),
            &self.metrics.pubsub_client.sender.subscribe,
        )
    }

    /// Encodes and sends an Unsubscribe message, tracked by the `unsubscribe` metrics.
    fn unsubscribe(&self, shard_id: &ShardId) {
        self.send(
            proto_pub_sub_message::Message::Unsubscribe(ProtoUnsubscribe {
                shard_id: shard_id.into_proto(),
            }),
            &self.metrics.pubsub_client.sender.unsubscribe,
        )
    }
}
575
/// Wraps a [PubSubSenderInternal], tracking which shards currently have live
/// subscription tokens so that subscriptions can be replayed on reconnect
/// and deduplicated across callers.
#[derive(Debug)]
struct SubscriptionTrackingSender {
    // The sender that actually transmits messages.
    delegate: Arc<dyn PubSubSenderInternal>,
    // Weak refs to outstanding tokens; an entry whose token has been dropped
    // is pruned lazily.
    subscribes: Arc<Mutex<BTreeMap<ShardId, Weak<ShardSubscriptionToken>>>>,
}
583
584impl SubscriptionTrackingSender {
585 fn new(sender: Arc<dyn PubSubSenderInternal>) -> Self {
586 Self {
587 delegate: sender,
588 subscribes: Default::default(),
589 }
590 }
591
592 fn subscriptions(&self) -> Vec<ShardId> {
593 let mut subscribes = self.subscribes.lock().expect("lock");
594 let mut out = Vec::with_capacity(subscribes.len());
595 subscribes.retain(|shard_id, token| {
596 if token.upgrade().is_none() {
597 false
598 } else {
599 debug!("reconnecting to: {}", shard_id);
600 out.push(*shard_id);
601 true
602 }
603 });
604 out
605 }
606}
607
impl PubSubSender for SubscriptionTrackingSender {
    /// Forwards the diff straight to the delegate; pushes need no tracking.
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff)
    }

    /// Returns a subscription token for `shard_id`, reusing the existing
    /// token if one is still alive so the service sees at most one
    /// subscription per shard per sender.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        let mut subscribes = self.subscribes.lock().expect("lock");
        if let Some(token) = subscribes.get(shard_id) {
            match token.upgrade() {
                // Token was dropped but not yet pruned; remove the stale
                // entry and fall through to create a fresh subscription.
                None => assert!(subscribes.remove(shard_id).is_some()),
                // Live token: hand out another strong ref to it.
                Some(token) => {
                    return Arc::clone(&token);
                }
            }
        }

        let pubsub_sender = Arc::clone(&self.delegate);
        let token = Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: pubsub_sender,
        });

        assert!(
            subscribes
                .insert(*shard_id, Arc::downgrade(&token))
                .is_none()
        );

        // Issue the actual subscribe only for a brand-new token.
        self.delegate.subscribe(shard_id);

        token
    }
}
641
/// A [PubSubSender] for use when client and server live in the same process:
/// wraps a delegate sender, counts pushes in metrics, and can be configured
/// to skip delegating subscribes entirely.
#[derive(Debug)]
pub struct MetricsSameProcessPubSubSender {
    // Snapshot of PUBSUB_SAME_PROCESS_DELEGATE_ENABLED taken at construction.
    delegate_subscribe: bool,
    metrics: Arc<Metrics>,
    delegate: Arc<dyn PubSubSender>,
}
651
652impl MetricsSameProcessPubSubSender {
653 pub fn new(
656 cfg: &PersistConfig,
657 pubsub_sender: Arc<dyn PubSubSender>,
658 metrics: Arc<Metrics>,
659 ) -> Self {
660 Self {
661 delegate_subscribe: PUBSUB_SAME_PROCESS_DELEGATE_ENABLED.get(cfg),
662 delegate: pubsub_sender,
663 metrics,
664 }
665 }
666}
667
impl PubSubSender for MetricsSameProcessPubSubSender {
    /// Delegates the push and counts it as succeeded (the in-process path
    /// cannot fail the way a network send can).
    fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
        self.delegate.push_diff(shard_id, diff);
        self.metrics.pubsub_client.sender.push.succeeded.inc();
    }

    /// Delegates the subscribe when enabled; otherwise hands back an inert
    /// token backed by [NoopPubSubSender] so callers are unaffected.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        if self.delegate_subscribe {
            let delegate = Arc::clone(&self.delegate);
            delegate.subscribe(shard_id)
        } else {
            Arc::new(ShardSubscriptionToken {
                shard_id: *shard_id,
                sender: Arc::new(NoopPubSubSender),
            })
        }
    }
}
691
/// A sender that drops everything: pushes go nowhere and subscriptions are
/// inert. Used for the noop connection and for disabled delegation.
#[derive(Debug)]
pub(crate) struct NoopPubSubSender;

impl PubSubSenderInternal for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
    fn subscribe(&self, _shard_id: &ShardId) {}
    fn unsubscribe(&self, _shard_id: &ShardId) {}
}

impl PubSubSender for NoopPubSubSender {
    fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}

    /// Returns a token whose drop-time unsubscribe is also a no-op.
    fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
        Arc::new(ShardSubscriptionToken {
            shard_id: *shard_id,
            sender: self,
        })
    }
}
711
/// Spawns a task that applies diffs from `pubsub_receiver` to the matching
/// shard states in `cache`, returning its join handle.
///
/// Keeps a local cache of weak shard-state refs as a fast path, falling back
/// to a `StateCache` lookup when a ref is missing or dead.
pub(crate) fn subscribe_state_cache_to_pubsub(
    cache: Arc<StateCache>,
    mut pubsub_receiver: Box<dyn PubSubReceiver>,
) -> JoinHandle<()> {
    let mut state_refs: HashMap<ShardId, Weak<dyn DynState>> = HashMap::new();
    let receiver_metrics = cache.metrics.pubsub_client.receiver.clone();

    mz_ore::task::spawn(
        || "persist::rpc::client::state_cache_diff_apply",
        async move {
            while let Some(msg) = pubsub_receiver.next().await {
                match msg.message {
                    Some(proto_pub_sub_message::Message::PushDiff(diff)) => {
                        receiver_metrics.push_received.inc();
                        let shard_id = diff.shard_id.into_rust().expect("valid shard id");
                        let diff = VersionedData {
                            seqno: diff.seqno.into_rust().expect("valid SeqNo"),
                            data: diff.diff,
                        };
                        debug!(
                            "applying pubsub diff {} {} {}",
                            shard_id,
                            diff.seqno,
                            diff.data.len()
                        );

                        let mut pushed_diff = false;
                        // Fast path: a previously cached weak ref to the
                        // shard's state is still alive.
                        if let Some(state_ref) = state_refs.get(&shard_id) {
                            if let Some(state) = state_ref.upgrade() {
                                state.push_diff(diff.clone());
                                pushed_diff = true;
                                receiver_metrics.state_pushed_diff_fast_path.inc();
                            }
                        }

                        // Slow path: refresh the weak ref from the state
                        // cache; drop the local entry if the shard's state
                        // is no longer cached or no longer alive.
                        if !pushed_diff {
                            let state_ref = cache.get_state_weak(&shard_id);
                            match state_ref {
                                None => {
                                    state_refs.remove(&shard_id);
                                }
                                Some(state_ref) => {
                                    if let Some(state) = state_ref.upgrade() {
                                        state.push_diff(diff);
                                        pushed_diff = true;
                                        state_refs.insert(shard_id, state_ref);
                                    } else {
                                        state_refs.remove(&shard_id);
                                    }
                                }
                            }

                            if pushed_diff {
                                receiver_metrics.state_pushed_diff_slow_path_succeeded.inc();
                            } else {
                                receiver_metrics.state_pushed_diff_slow_path_failed.inc();
                            }
                        }

                        // Record approximate end-to-end latency using the
                        // sender's wall-clock timestamp (clocks may skew, so
                        // clamp negative deltas to zero).
                        if let Some(send_timestamp) = msg.timestamp {
                            let send_timestamp =
                                send_timestamp.into_rust().expect("valid timestamp");
                            let now = SystemTime::now()
                                .duration_since(SystemTime::UNIX_EPOCH)
                                .expect("failed to get millis since epoch");
                            receiver_metrics
                                .approx_diff_latency_seconds
                                .observe((now.saturating_sub(send_timestamp)).as_secs_f64());
                        }
                    }
                    // Anything other than PushDiff is unexpected on the
                    // client side; count and log it.
                    ref msg @ None | ref msg @ Some(_) => {
                        warn!("pubsub client received unexpected message: {:?}", msg);
                        receiver_metrics.unknown_message_received.inc();
                    }
                }
            }
        },
    )
}
798
/// Server-side state shared by all PubSub connections: which connections
/// exist and which of them subscribe to which shards.
#[derive(Debug)]
pub(crate) struct PubSubState {
    // Monotonic source of unique connection ids.
    connection_id_counter: AtomicUsize,
    // For each shard, the notifier channel of every subscribed connection,
    // keyed by connection id.
    shard_subscribers:
        Arc<RwLock<BTreeMap<ShardId, BTreeMap<usize, Sender<Result<ProtoPubSubMessage, Status>>>>>>,
    // The set of currently active connection ids.
    connections: Arc<RwLock<HashSet<usize>>>,
    // Server-side metrics.
    metrics: Arc<PubSubServerMetrics>,
}
812
impl PubSubState {
    /// Registers a new connection with the given notifier channel, returning
    /// a [PubSubConnection] handle that deregisters itself on drop.
    fn new_connection(
        self: Arc<Self>,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    ) -> PubSubConnection {
        let connection_id = self.connection_id_counter.fetch_add(1, Ordering::SeqCst);
        {
            debug!("inserting connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            // Ids are unique by construction; a duplicate is a bug.
            assert!(connections.insert(connection_id));
        }

        self.metrics.active_connections.inc();
        PubSubConnection {
            connection_id,
            notifier,
            state: self,
        }
    }

    /// Removes a connection and all of its shard subscriptions.
    /// Panics if the connection id is unknown.
    fn remove_connection(&self, connection_id: usize) {
        let now = Instant::now();

        {
            debug!("removing connid: {}", connection_id);
            let mut connections = self.connections.write().expect("lock");
            assert!(
                connections.remove(&connection_id),
                "unknown connection id: {}",
                connection_id
            );
        }

        {
            // Drop this connection from every shard's subscriber map and
            // prune shards left with no subscribers.
            let mut subscribers = self.shard_subscribers.write().expect("lock poisoned");
            subscribers.retain(|_shard, connections_for_shard| {
                connections_for_shard.remove(&connection_id);
                !connections_for_shard.is_empty()
            });
        }

        self.metrics
            .connection_cleanup_seconds
            .inc_by(now.elapsed().as_secs_f64());
        self.metrics.active_connections.dec();
    }

    /// Broadcasts a diff to every subscriber of `shard_id` except the
    /// originating connection. Sends are non-blocking: a full per-connection
    /// channel drops the message (counted in metrics) rather than stalling
    /// the pusher. Panics if the connection id is unknown.
    fn push_diff(&self, connection_id: usize, shard_id: &ShardId, data: &VersionedData) {
        let now = Instant::now();
        self.metrics.push_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        let subscribers = self.shard_subscribers.read().expect("lock poisoned");
        if let Some(subscribed_connections) = subscribers.get(shard_id) {
            let mut num_sent = 0;
            let mut data_size = 0;

            for (subscribed_conn_id, tx) in subscribed_connections {
                // Don't echo the diff back to the connection that pushed it.
                if *subscribed_conn_id == connection_id {
                    continue;
                }
                debug!(
                    "server forwarding req to conn {}: {} {} {}",
                    subscribed_conn_id,
                    &shard_id,
                    data.seqno,
                    data.data.len()
                );
                let req = create_request(proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
                    seqno: data.seqno.into_proto(),
                    shard_id: shard_id.to_string(),
                    diff: Bytes::clone(&data.data),
                }));
                // Every request encodes the same payload, so one size
                // suffices for the bytes metric below.
                data_size = req.encoded_len();
                match tx.try_send(Ok(req)) {
                    Ok(_) => {
                        num_sent += 1;
                    }
                    Err(TrySendError::Full(_)) => {
                        self.metrics.broadcasted_diff_dropped_channel_full.inc();
                    }
                    // A closed receiver will be cleaned up when its
                    // connection is removed; nothing to do here.
                    Err(TrySendError::Closed(_)) => {}
                };
            }

            self.metrics.broadcasted_diff_count.inc_by(num_sent);
            self.metrics
                .broadcasted_diff_bytes
                .inc_by(num_sent * u64::cast_from(data_size));
        }

        self.metrics
            .push_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Subscribes a connection to a shard, storing its notifier channel.
    /// Idempotent per (connection, shard). Panics if the connection id is
    /// unknown.
    fn subscribe(
        &self,
        connection_id: usize,
        notifier: Sender<Result<ProtoPubSubMessage, Status>>,
        shard_id: &ShardId,
    ) {
        let now = Instant::now();
        self.metrics.subscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            subscribed_shards
                .entry(*shard_id)
                .or_default()
                .insert(connection_id, notifier);
        }

        self.metrics
            .subscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Unsubscribes a connection from a shard (no-op if not subscribed),
    /// pruning the shard entry if it is left empty. Panics if the connection
    /// id is unknown.
    fn unsubscribe(&self, connection_id: usize, shard_id: &ShardId) {
        let now = Instant::now();
        self.metrics.unsubscribe_call_count.inc();

        assert!(
            self.connections
                .read()
                .expect("lock")
                .contains(&connection_id),
            "unknown connection id: {}",
            connection_id
        );

        {
            let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
            if let Entry::Occupied(mut entry) = subscribed_shards.entry(*shard_id) {
                let subscribed_connections = entry.get_mut();
                subscribed_connections.remove(&connection_id);

                if subscribed_connections.is_empty() {
                    entry.remove_entry();
                }
            }
        }

        self.metrics
            .unsubscribe_seconds
            .inc_by(now.elapsed().as_secs_f64());
    }

    /// Creates an empty state with throwaway metrics, for tests.
    #[cfg(test)]
    fn new_for_test() -> Self {
        Self {
            connection_id_counter: AtomicUsize::new(0),
            shard_subscribers: Default::default(),
            connections: Default::default(),
            metrics: Arc::new(PubSubServerMetrics::new(&MetricsRegistry::new())),
        }
    }

    /// Returns the ids of all active connections, for tests.
    #[cfg(test)]
    fn active_connections(&self) -> HashSet<usize> {
        self.connections.read().expect("lock").clone()
    }

    /// Returns the shards the given connection subscribes to, for tests.
    #[cfg(test)]
    fn subscriptions(&self, connection_id: usize) -> HashSet<ShardId> {
        let mut shards = HashSet::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            if subscribed_connections.contains_key(&connection_id) {
                shards.insert(*shard);
            }
        }

        shards
    }

    /// Returns the number of subscribers per shard, for tests.
    #[cfg(test)]
    fn shard_subscription_counts(&self) -> mz_ore::collections::HashMap<ShardId, usize> {
        let mut shards = mz_ore::collections::HashMap::new();

        let subscribers = self.shard_subscribers.read().expect("lock");
        for (shard, subscribed_connections) in subscribers.iter() {
            shards.insert(*shard, subscribed_connections.len());
        }

        shards
    }
}
1020
/// A gRPC-based server for the Persist PubSub service.
#[derive(Debug)]
pub struct PersistGrpcPubSubServer {
    // Configuration, consulted for dyncfg-driven behavior.
    cfg: PersistConfig,
    // Shared connection/subscription state.
    state: Arc<PubSubState>,
}
1027
1028impl PersistGrpcPubSubServer {
1029 pub fn new(cfg: &PersistConfig, metrics_registry: &MetricsRegistry) -> Self {
1031 let metrics = PubSubServerMetrics::new(metrics_registry);
1032 let state = Arc::new(PubSubState {
1033 connection_id_counter: AtomicUsize::new(0),
1034 shard_subscribers: Default::default(),
1035 connections: Default::default(),
1036 metrics: Arc::new(metrics),
1037 });
1038
1039 PersistGrpcPubSubServer {
1040 cfg: cfg.clone(),
1041 state,
1042 }
1043 }
1044
1045 pub fn new_same_process_connection(&self) -> PubSubClientConnection {
1049 let (tx, rx) =
1050 tokio::sync::mpsc::channel(PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&self.cfg));
1051 let sender: Arc<dyn PubSubSender> = Arc::new(SubscriptionTrackingSender::new(Arc::new(
1052 Arc::clone(&self.state).new_connection(tx),
1053 )));
1054
1055 PubSubClientConnection {
1056 sender,
1057 receiver: Box::new(
1058 ReceiverStream::new(rx).map(|x| x.expect("cannot receive grpc errors locally")),
1059 ),
1060 }
1061 }
1062
1063 pub async fn serve(self, listen_addr: SocketAddr) -> Result<(), anyhow::Error> {
1065 tonic::transport::Server::builder()
1067 .add_service(ProtoPersistPubSubServer::new(self).max_decoding_message_size(usize::MAX))
1068 .serve(listen_addr)
1069 .await?;
1070 Ok(())
1071 }
1072
1073 pub async fn serve_with_stream(
1076 self,
1077 listener: tokio_stream::wrappers::TcpListenerStream,
1078 ) -> Result<(), anyhow::Error> {
1079 tonic::transport::Server::builder()
1080 .add_service(ProtoPersistPubSubServer::new(self))
1081 .serve_with_incoming(listener)
1082 .await?;
1083 Ok(())
1084 }
1085}
1086
#[async_trait]
impl proto_persist_pub_sub_server::ProtoPersistPubSub for PersistGrpcPubSubServer {
    type PubSubStream = Pin<Box<dyn Stream<Item = Result<ProtoPubSubMessage, Status>> + Send>>;

    /// Handles a bidirectional PubSub stream: registers the caller as a new
    /// connection, applies its incoming Push/Subscribe/Unsubscribe messages
    /// to the shared state in a spawned task, and returns the stream over
    /// which broadcasts are delivered back to it.
    #[mz_ore::instrument(name = "persist::rpc::server", level = "info")]
    async fn pub_sub(
        &self,
        request: Request<Streaming<ProtoPubSubMessage>>,
    ) -> Result<Response<Self::PubSubStream>, Status> {
        // Pull the caller's self-reported id out of request metadata; used
        // only for logging.
        let caller_id = request
            .metadata()
            .get(AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY))
            .map(|key| key.to_str().ok())
            .flatten()
            .map(|key| key.to_string())
            .unwrap_or_else(|| "unknown".to_string());
        info!("Received Persist PubSub connection from: {:?}", caller_id);

        let mut in_stream = request.into_inner();
        let (tx, rx) =
            tokio::sync::mpsc::channel(PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE.get(&self.cfg));

        let caller = caller_id.clone();
        let cfg = Arc::clone(&self.cfg.configs);
        let server_state = Arc::clone(&self.state);
        let connection_span = info_span!("connection", caller_id);
        mz_ore::task::spawn(
            || format!("persist_pubsub_connection({})", caller),
            async move {
                // `connection` deregisters itself from `server_state` when it
                // is dropped (i.e. when this task exits).
                let connection = server_state.new_connection(tx);
                while let Some(result) = in_stream.next().await {
                    let req = match result {
                        Ok(req) => req,
                        Err(err) => {
                            warn!("pubsub connection err: {}", err);
                            break;
                        }
                    };

                    match req.message {
                        None => {
                            warn!("received empty message from: {}", caller_id);
                        }
                        Some(proto_pub_sub_message::Message::PushDiff(req)) => {
                            let shard_id = req.shard_id.parse().expect("valid shard id");
                            let diff = VersionedData {
                                seqno: req.seqno.into_rust().expect("valid seqno"),
                                data: req.diff.clone(),
                            };
                            // Re-broadcasting diffs can be toggled off via dyncfg.
                            if PUBSUB_PUSH_DIFF_ENABLED.get(&cfg) {
                                connection.push_diff(&shard_id, &diff);
                            }
                        }
                        Some(proto_pub_sub_message::Message::Subscribe(diff)) => {
                            let shard_id = diff.shard_id.parse().expect("valid shard id");
                            connection.subscribe(&shard_id);
                        }
                        Some(proto_pub_sub_message::Message::Unsubscribe(diff)) => {
                            let shard_id = diff.shard_id.parse().expect("valid shard id");
                            connection.unsubscribe(&shard_id);
                        }
                    }
                }

                info!("Persist PubSub connection ended: {:?}", caller_id);
            }
            .instrument(connection_span),
        );

        let out_stream: Self::PubSubStream = Box::pin(ReceiverStream::new(rx));
        Ok(Response::new(out_stream))
    }
}
1163
/// A handle to one registered server-side connection. Implements
/// [PubSubSenderInternal] against the shared [PubSubState] and removes the
/// connection from that state on drop.
#[derive(Debug)]
pub(crate) struct PubSubConnection {
    // Unique id assigned by `PubSubState::new_connection`.
    connection_id: usize,
    // Channel over which broadcasts are delivered to this connection.
    notifier: Sender<Result<ProtoPubSubMessage, Status>>,
    // Shared server state this connection is registered in.
    state: Arc<PubSubState>,
}
1173
1174impl PubSubSenderInternal for PubSubConnection {
1175 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
1176 self.state.push_diff(self.connection_id, shard_id, diff)
1177 }
1178
1179 fn subscribe(&self, shard_id: &ShardId) {
1180 self.state
1181 .subscribe(self.connection_id, self.notifier.clone(), shard_id)
1182 }
1183
1184 fn unsubscribe(&self, shard_id: &ShardId) {
1185 self.state.unsubscribe(self.connection_id, shard_id)
1186 }
1187}
1188
impl Drop for PubSubConnection {
    /// Deregisters the connection (and all its subscriptions) from the
    /// shared server state.
    fn drop(&mut self) {
        self.state.remove_connection(self.connection_id)
    }
}
1194
1195#[cfg(test)]
1196mod pubsub_state {
1197 use std::str::FromStr;
1198 use std::sync::Arc;
1199 use std::sync::LazyLock;
1200
1201 use bytes::Bytes;
1202 use mz_ore::collections::HashSet;
1203 use mz_persist::location::{SeqNo, VersionedData};
1204 use mz_proto::RustType;
1205 use tokio::sync::mpsc::Receiver;
1206 use tokio::sync::mpsc::error::TryRecvError;
1207 use tonic::Status;
1208
1209 use crate::ShardId;
1210 use crate::internal::service::ProtoPubSubMessage;
1211 use crate::internal::service::proto_pub_sub_message::Message;
1212 use crate::rpc::{PubSubSenderInternal, PubSubState};
1213
    // Two fixed shard ids used across the tests below.
    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());

    // Two distinct diffs (different seqnos/payloads) used as test payloads.
    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };

    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };
1228
    // Every PubSubState operation asserts that the connection id is known;
    // these tests verify the panic on a state with no connections.

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_push_diff() {
        let state = Arc::new(PubSubState::new_for_test());
        state.push_diff(100, &SHARD_ID_0, &VERSIONED_DATA_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_subscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        let (tx, _) = tokio::sync::mpsc::channel(100);
        state.subscribe(100, tx, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_unsubscribe() {
        let state = Arc::new(PubSubState::new_for_test());
        state.unsubscribe(100, &SHARD_ID_0);
    }

    #[mz_ore::test]
    #[should_panic(expected = "unknown connection id: 100")]
    fn test_zero_connections_remove() {
        let state = Arc::new(PubSubState::new_for_test());
        state.remove_connection(100)
    }
1257
1258 #[mz_ore::test]
1259 fn test_single_connection() {
1260 let state = Arc::new(PubSubState::new_for_test());
1261
1262 let (tx, mut rx) = tokio::sync::mpsc::channel(100);
1263 let connection = Arc::clone(&state).new_connection(tx);
1264
1265 assert_eq!(
1266 state.active_connections(),
1267 HashSet::from([connection.connection_id])
1268 );
1269
1270 assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
1272
1273 connection.push_diff(
1274 &SHARD_ID_0,
1275 &VersionedData {
1276 seqno: SeqNo::minimum(),
1277 data: Bytes::new(),
1278 },
1279 );
1280
1281 assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
1283
1284 connection.subscribe(&SHARD_ID_0);
1286 assert_eq!(
1287 state.subscriptions(connection.connection_id),
1288 HashSet::from([SHARD_ID_0.clone()])
1289 );
1290
1291 connection.unsubscribe(&SHARD_ID_0);
1293 assert!(state.subscriptions(connection.connection_id).is_empty());
1294
1295 connection.subscribe(&SHARD_ID_0);
1297 connection.subscribe(&SHARD_ID_1);
1298 assert_eq!(
1299 state.subscriptions(connection.connection_id),
1300 HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
1301 );
1302
1303 connection.subscribe(&SHARD_ID_0);
1305 connection.subscribe(&SHARD_ID_0);
1306 assert_eq!(
1307 state.subscriptions(connection.connection_id),
1308 HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
1309 );
1310
1311 let connection_id = connection.connection_id;
1313 drop(connection);
1314 assert!(state.subscriptions(connection_id).is_empty());
1315 assert!(state.active_connections().is_empty());
1316 }
1317
1318 #[mz_ore::test]
1319 fn test_many_connection() {
1320 let state = Arc::new(PubSubState::new_for_test());
1321
1322 let (tx1, mut rx1) = tokio::sync::mpsc::channel(100);
1323 let conn1 = Arc::clone(&state).new_connection(tx1);
1324
1325 let (tx2, mut rx2) = tokio::sync::mpsc::channel(100);
1326 let conn2 = Arc::clone(&state).new_connection(tx2);
1327
1328 let (tx3, mut rx3) = tokio::sync::mpsc::channel(100);
1329 let conn3 = Arc::clone(&state).new_connection(tx3);
1330
1331 conn1.subscribe(&SHARD_ID_0);
1332 conn2.subscribe(&SHARD_ID_0);
1333 conn2.subscribe(&SHARD_ID_1);
1334
1335 assert_eq!(
1336 state.active_connections(),
1337 HashSet::from([
1338 conn1.connection_id,
1339 conn2.connection_id,
1340 conn3.connection_id
1341 ])
1342 );
1343
1344 conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1346 assert_push(&mut rx1, &SHARD_ID_0, &VERSIONED_DATA_0);
1347 assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1348 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1349
1350 conn1.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1352 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1353 assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1354 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1355
1356 conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
1358 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1359 assert_push(&mut rx2, &SHARD_ID_1, &VERSIONED_DATA_1);
1360 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1361
1362 conn2.unsubscribe(&SHARD_ID_1);
1364 conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
1365 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1366 assert!(matches!(rx2.try_recv(), Err(TryRecvError::Empty)));
1367 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1368
1369 let conn1_id = conn1.connection_id;
1371 drop(conn1);
1372 conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1373 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Disconnected)));
1374 assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1375 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1376
1377 assert!(state.subscriptions(conn1_id).is_empty());
1378 assert_eq!(
1379 state.subscriptions(conn2.connection_id),
1380 HashSet::from([*SHARD_ID_0])
1381 );
1382 assert_eq!(state.subscriptions(conn3.connection_id), HashSet::new());
1383 assert_eq!(
1384 state.active_connections(),
1385 HashSet::from([conn2.connection_id, conn3.connection_id])
1386 );
1387 }
1388
1389 fn assert_push(
1390 rx: &mut Receiver<Result<ProtoPubSubMessage, Status>>,
1391 shard: &ShardId,
1392 data: &VersionedData,
1393 ) {
1394 let message = rx
1395 .try_recv()
1396 .expect("message in channel")
1397 .expect("pubsub")
1398 .message
1399 .expect("proto contains message");
1400 match message {
1401 Message::PushDiff(x) => {
1402 assert_eq!(x.shard_id, shard.into_proto());
1403 assert_eq!(x.seqno, data.seqno.into_proto());
1404 assert_eq!(x.diff, data.data);
1405 }
1406 Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
1407 };
1408 }
1409}
1410
#[cfg(test)]
mod grpc {
    //! End-to-end tests of the gRPC PubSub client/server pair over real TCP
    //! sockets, including behavior across server shutdowns and restarts.
    //! The tests poll the shared server state for convergence rather than
    //! synchronizing explicitly.

    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::str::FromStr;
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use futures_util::FutureExt;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::assert_none;
    use mz_ore::collections::HashMap;
    use mz_ore::metrics::MetricsRegistry;
    use mz_persist::location::{SeqNo, VersionedData};
    use mz_proto::RustType;
    use std::sync::LazyLock;
    use tokio::net::TcpListener;
    use tokio_stream::StreamExt;
    use tokio_stream::wrappers::TcpListenerStream;

    use crate::ShardId;
    use crate::cfg::PersistConfig;
    use crate::internal::service::ProtoPubSubMessage;
    use crate::internal::service::proto_pub_sub_message::Message;
    use crate::metrics::Metrics;
    use crate::rpc::{
        GrpcPubSubClient, PUBSUB_CLIENT_ENABLED, PUBSUB_RECONNECT_BACKOFF, PersistGrpcPubSubServer,
        PersistPubSubClient, PersistPubSubClientConfig, PubSubState,
    };

    // Two distinct shard ids shared by the tests below.
    static SHARD_ID_0: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
    static SHARD_ID_1: LazyLock<ShardId> =
        LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
    // Two distinguishable diffs for asserting which payload was routed where.
    const VERSIONED_DATA_0: VersionedData = VersionedData {
        seqno: SeqNo(0),
        data: Bytes::from_static(&[0, 1, 2, 3]),
    };
    const VERSIONED_DATA_1: VersionedData = VersionedData {
        seqno: SeqNo(1),
        data: Bytes::from_static(&[4, 5, 6, 7]),
    };

    // Upper bounds on how long `poll_until_true` waits; these only determine
    // how quickly a failing test gives up.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
    const SUBSCRIPTIONS_TIMEOUT: Duration = Duration::from_secs(3);
    const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);

    // Server-side connection tracking: a connected, subscribed client shows
    // up in the server state and disappears when its runtime is torn down.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] fn grpc_server() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        // Separate runtimes so the client can be shut down independently of
        // the server, simulating a client process exit.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");

        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        // Spawn a client on its own runtime that subscribes and then parks
        // forever, holding its subscription token.
        {
            let _guard = client_runtime.enter();
            mz_ore::task::spawn(|| "client".to_string(), async move {
                let client = GrpcPubSubClient::connect(
                    PersistPubSubClientConfig {
                        url: format!("http://{}", addr),
                        caller_id: "client".to_string(),
                        persist_cfg: test_persist_config(),
                    },
                    metrics,
                );
                let _token = client.sender.subscribe(&SHARD_ID_0);
                tokio::time::sleep(Duration::MAX).await;
            });
        }

        // The server should observe one connection with one subscription.
        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await
        });

        // Kill the client runtime; the server should clean up the connection
        // and its subscriptions.
        client_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().is_empty()
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::new()
            })
            .await
        });
    }

    // The client sender buffers and replays its live subscriptions across
    // server unavailability: a subscription made before the server exists,
    // and one made while the server is down, both take effect on (re)connect.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] fn grpc_client_sender_reconnects() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // Connect the client before any server is listening.
        let client = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // Hold a token for shard 0; the shard 1 token is dropped immediately,
        // so that subscription should never be visible to the server.
        let _token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        drop(_token_2);

        // Now start the server; the client should connect and replay only
        // the still-live shard 0 subscription.
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
            })
            .await;
        });

        // Take the server down entirely.
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        // Subscribe to shard 1 while the server is unreachable.
        let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);

        // Restart a fresh server on the same address; the client should
        // reconnect and replay both live subscriptions.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 1
            })
            .await;

            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts()
                    == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
            })
            .await;
        });
    }

    // Subscription tokens are reference counted: repeated subscriptions to
    // the same shard share one token/server-side subscription, which is
    // released only when the last token is dropped.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] async fn grpc_client_sender_subscription_tokens() {
        let metrics = Arc::new(Metrics::new(
            &test_persist_config(),
            &MetricsRegistry::new(),
        ));

        let (addr, tcp_listener_stream) = new_tcp_listener().await;
        let server_state = spawn_server(tcp_listener_stream).await;

        let client = GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client".to_string(),
                persist_cfg: test_persist_config(),
            },
            metrics,
        );

        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;

        // One token -> one server-side subscription.
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // Dropping the only token unsubscribes.
        drop(token);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // Re-subscribing after a full unsubscribe works again.
        let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // Additional subscriptions to the same shard clone the same Arc'd
        // token (strong_count == 3) and do not add server-side subscriptions.
        let token2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let token3 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        assert_eq!(Arc::strong_count(&token), 3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;

        // Only when all three tokens are gone does the server unsubscribe.
        drop(token);
        drop(token2);
        drop(token3);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await;

        // Tokens for distinct shards are independent subscriptions.
        let _token0 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
        let _token1 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts()
                == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
        })
        .await;
    }

    // The client receiver yields diffs pushed by *other* clients (never its
    // own), yields nothing while disconnected, and resumes after the server
    // is restarted on the same address.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] fn grpc_client_receiver() {
        // NOTE(review): metrics are built from `PersistConfig::new_for_tests()`
        // here while the clients use `test_persist_config()` — presumably fine
        // since only the clients read the pubsub configs; confirm.
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
        let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

        // Two clients created before the server is up.
        let mut client_1 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_1".to_string(),
                    persist_cfg: test_persist_config(),
                },
                Arc::clone(&metrics),
            )
        });
        let mut client_2 = client_runtime.block_on(async {
            GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client_2".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            )
        });

        // No server yet: receivers are pending (not terminated).
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 2
        }));

        // Connected but nothing pushed: still nothing to receive.
        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        let _token_client_1 = Arc::clone(&client_1.sender).subscribe(&SHARD_ID_0);
        let _token_client_2 = Arc::clone(&client_2.sender).subscribe(&SHARD_ID_0);
        server_runtime.block_on(poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
        }));

        // client_1 pushes: only client_2 receives (no self-echo).
        client_1.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_1);
        assert_none!(client_1.receiver.next().now_or_never());
        client_runtime.block_on(async {
            assert_push(
                client_2.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_1,
            )
        });

        // Kill the server; receivers go quiet but do not error/terminate.
        server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

        assert_none!(client_1.receiver.next().now_or_never());
        assert_none!(client_2.receiver.next().now_or_never());

        // Restart the server on the same address; both clients should
        // reconnect and re-establish their subscriptions.
        let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
        let tcp_listener_stream = server_runtime.block_on(async {
            TcpListenerStream::new(
                TcpListener::bind(addr)
                    .await
                    .expect("can bind to previous addr"),
            )
        });
        let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

        server_runtime.block_on(async {
            poll_until_true(CONNECT_TIMEOUT, || {
                server_state.active_connections().len() == 2
            })
            .await;
            poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
                server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
            })
            .await;
        });

        // Pushes flow again after the reconnect, still without self-echo.
        client_2.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
        client_runtime.block_on(async {
            assert_push(
                client_1.receiver.next().await.expect("has diff"),
                &SHARD_ID_0,
                &VERSIONED_DATA_0,
            )
        });
        assert_none!(client_2.receiver.next().now_or_never());
    }

    /// Binds a listener on an OS-assigned localhost port and returns the
    /// resolved address together with the listener stream.
    async fn new_tcp_listener() -> (SocketAddr, TcpListenerStream) {
        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let tcp_listener = TcpListener::bind(addr).await.expect("tcp listener");

        (
            tcp_listener.local_addr().expect("bound to local address"),
            TcpListenerStream::new(tcp_listener),
        )
    }

    /// Spawns a PubSub server task serving `tcp_listener_stream` on the
    /// current runtime and returns a handle to its shared state. The serve
    /// task is detached and runs until its runtime shuts down.
    #[allow(clippy::unused_async)]
    async fn spawn_server(tcp_listener_stream: TcpListenerStream) -> Arc<PubSubState> {
        let server = PersistGrpcPubSubServer::new(&test_persist_config(), &MetricsRegistry::new());
        let server_state = Arc::clone(&server.state);

        let _server_task = mz_ore::task::spawn(|| "server".to_string(), async move {
            server.serve_with_stream(tcp_listener_stream).await
        });
        server_state
    }

    /// Polls `f` every millisecond until it returns true, panicking if it
    /// does not within `timeout`.
    async fn poll_until_true<F>(timeout: Duration, f: F)
    where
        F: Fn() -> bool,
    {
        let now = Instant::now();
        loop {
            if f() {
                return;
            }

            if now.elapsed() > timeout {
                panic!("timed out");
            }

            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }

    /// Asserts that `message` is a `PushDiff` for `shard` carrying exactly
    /// `data`.
    fn assert_push(message: ProtoPubSubMessage, shard: &ShardId, data: &VersionedData) {
        let message = message.message.expect("proto contains message");
        match message {
            Message::PushDiff(x) => {
                assert_eq!(x.shard_id, shard.into_proto());
                assert_eq!(x.seqno, data.seqno.into_proto());
                assert_eq!(x.diff, data.data);
            }
            Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
        };
    }

    /// A `PersistConfig` with PubSub enabled and a zero reconnect backoff so
    /// reconnect tests run quickly.
    fn test_persist_config() -> PersistConfig {
        let cfg = PersistConfig::new_for_tests();

        let mut updates = ConfigUpdates::default();
        updates.add(&PUBSUB_CLIENT_ENABLED, true);
        updates.add(&PUBSUB_RECONNECT_BACKOFF, Duration::ZERO);
        cfg.apply_from(&updates);

        cfg
    }
}