1use std::collections::BTreeMap;
13use std::collections::btree_map::Entry;
14use std::fmt::{Debug, Formatter};
15use std::net::SocketAddr;
16use std::pin::Pin;
17use std::str::FromStr;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::{Arc, Mutex, RwLock, Weak};
20use std::time::{Duration, Instant, SystemTime};
21
22use anyhow::{Error, anyhow};
23use async_trait::async_trait;
24use bytes::Bytes;
25use futures::Stream;
26use futures_util::StreamExt;
27use mz_dyncfg::Config;
28use mz_ore::cast::CastFrom;
29use mz_ore::collections::{HashMap, HashSet};
30use mz_ore::metrics::MetricsRegistry;
31use mz_ore::retry::RetryResult;
32use mz_ore::task::JoinHandle;
33use mz_persist::location::VersionedData;
34use mz_proto::{ProtoType, RustType};
35use prost::Message;
36use tokio::sync::mpsc::Sender;
37use tokio::sync::mpsc::error::TrySendError;
38use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
39use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
40use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap};
41use tonic::transport::Endpoint;
42use tonic::{Extensions, Request, Response, Status, Streaming};
43use tracing::{Instrument, debug, error, info, info_span, warn};
44
45use crate::ShardId;
46use crate::cache::{DynState, StateCache};
47use crate::cfg::PersistConfig;
48use crate::internal::metrics::{PubSubClientCallMetrics, PubSubServerMetrics};
49use crate::internal::service::proto_persist_pub_sub_client::ProtoPersistPubSubClient;
50use crate::internal::service::proto_persist_pub_sub_server::ProtoPersistPubSubServer;
51use crate::internal::service::{
52 ProtoPubSubMessage, ProtoPushDiff, ProtoSubscribe, ProtoUnsubscribe,
53 proto_persist_pub_sub_server, proto_pub_sub_message,
54};
55use crate::metrics::Metrics;
56
57pub(crate) const PUBSUB_CLIENT_ENABLED: Config<bool> = Config::new(
59 "persist_pubsub_client_enabled",
60 true,
61 "Whether to connect to the Persist PubSub service.",
62);
63
64pub(crate) const PUBSUB_PUSH_DIFF_ENABLED: Config<bool> = Config::new(
68 "persist_pubsub_push_diff_enabled",
69 true,
70 "Whether to push state diffs to Persist PubSub.",
71);
72
73pub(crate) const PUBSUB_SAME_PROCESS_DELEGATE_ENABLED: Config<bool> = Config::new(
77 "persist_pubsub_same_process_delegate_enabled",
78 true,
79 "Whether to push state diffs to Persist PubSub on the same process.",
80);
81
82pub(crate) const PUBSUB_CONNECT_ATTEMPT_TIMEOUT: Config<Duration> = Config::new(
84 "persist_pubsub_connect_attempt_timeout",
85 Duration::from_secs(5),
86 "Timeout per connection attempt to Persist PubSub service.",
87);
88
89pub(crate) const PUBSUB_REQUEST_TIMEOUT: Config<Duration> = Config::new(
91 "persist_pubsub_request_timeout",
92 Duration::from_secs(5),
93 "Timeout per request attempt to Persist PubSub service.",
94);
95
96pub(crate) const PUBSUB_CONNECT_MAX_BACKOFF: Config<Duration> = Config::new(
98 "persist_pubsub_connect_max_backoff",
99 Duration::from_secs(60),
100 "Maximum backoff when retrying connection establishment to Persist PubSub service.",
101);
102
103pub(crate) const PUBSUB_CLIENT_SENDER_CHANNEL_SIZE: Config<usize> = Config::new(
105 "persist_pubsub_client_sender_channel_size",
106 25,
107 "Size of channel used to buffer send messages to PubSub service.",
108);
109
110pub(crate) const PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE: Config<usize> = Config::new(
112 "persist_pubsub_client_receiver_channel_size",
113 25,
114 "Size of channel used to buffer received messages from PubSub service.",
115);
116
117pub(crate) const PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE: Config<usize> = Config::new(
119 "persist_pubsub_server_connection_channel_size",
120 25,
121 "Size of channel used per connection to buffer broadcasted messages from PubSub server.",
122);
123
124pub(crate) const PUBSUB_STATE_CACHE_SHARD_REF_CHANNEL_SIZE: Config<usize> = Config::new(
126 "persist_pubsub_state_cache_shard_ref_channel_size",
127 25,
128 "Size of channel used by the state cache to broadcast shard state references.",
129);
130
131pub(crate) const PUBSUB_RECONNECT_BACKOFF: Config<Duration> = Config::new(
133 "persist_pubsub_reconnect_backoff",
134 Duration::from_secs(5),
135 "Backoff after an established connection to Persist PubSub service fails.",
136);
137
138pub trait PersistPubSubClient {
143 fn connect(
145 pubsub_config: PersistPubSubClientConfig,
146 metrics: Arc<Metrics>,
147 ) -> PubSubClientConnection;
148}
149
150#[derive(Debug)]
152pub struct PubSubClientConnection {
153 pub sender: Arc<dyn PubSubSender>,
155 pub receiver: Box<dyn PubSubReceiver>,
157}
158
159impl PubSubClientConnection {
160 pub fn new(sender: Arc<dyn PubSubSender>, receiver: Box<dyn PubSubReceiver>) -> Self {
162 Self { sender, receiver }
163 }
164
165 pub fn noop() -> Self {
167 Self {
168 sender: Arc::new(NoopPubSubSender),
169 receiver: Box::new(futures::stream::empty()),
170 }
171 }
172}
173
174pub trait PubSubSender: std::fmt::Debug + Send + Sync {
176 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);
178
179 fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken>;
187}
188
189trait PubSubSenderInternal: Debug + Send + Sync {
194 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData);
196
197 fn subscribe(&self, shard_id: &ShardId);
201
202 fn unsubscribe(&self, shard_id: &ShardId);
206}
207
208pub trait PubSubReceiver:
213 Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
214{
215}
216
217impl<T> PubSubReceiver for T where
218 T: Stream<Item = ProtoPubSubMessage> + Send + Unpin + std::fmt::Debug
219{
220}
221
222pub struct ShardSubscriptionToken {
227 pub(crate) shard_id: ShardId,
228 sender: Arc<dyn PubSubSenderInternal>,
229}
230
231impl Debug for ShardSubscriptionToken {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 let ShardSubscriptionToken {
234 shard_id,
235 sender: _sender,
236 } = self;
237 write!(f, "ShardSubscriptionToken({})", shard_id)
238 }
239}
240
241impl Drop for ShardSubscriptionToken {
242 fn drop(&mut self) {
243 self.sender.unsubscribe(&self.shard_id);
244 }
245}
246
/// gRPC metadata key under which a client announces its caller id.
pub const PERSIST_PUBSUB_CALLER_KEY: &'static str = "persist-pubsub-caller-id";
249
250#[derive(Debug)]
252pub struct PersistPubSubClientConfig {
253 pub url: String,
255 pub caller_id: String,
257 pub persist_cfg: PersistConfig,
259}
260
/// gRPC-backed implementation of [`PersistPubSubClient`].
#[derive(Debug)]
pub struct GrpcPubSubClient;
268
269impl GrpcPubSubClient {
270 async fn reconnect_to_server_forever(
271 send_requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
272 receiver_input: &tokio::sync::mpsc::Sender<ProtoPubSubMessage>,
273 sender: Arc<SubscriptionTrackingSender>,
274 metadata: MetadataMap,
275 config: PersistPubSubClientConfig,
276 metrics: Arc<Metrics>,
277 ) {
278 config.persist_cfg.configs_synced_once().await;
283
284 let mut is_first_connection_attempt = true;
285 loop {
286 let sender = Arc::clone(&sender);
287 metrics.pubsub_client.grpc_connection.connected.set(0);
288
289 if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
290 tokio::time::sleep(Duration::from_secs(5)).await;
291 continue;
292 }
293
294 if is_first_connection_attempt {
296 is_first_connection_attempt = false;
297 } else {
298 tokio::time::sleep(PUBSUB_RECONNECT_BACKOFF.get(&config.persist_cfg)).await;
299 }
300
301 info!("Connecting to Persist PubSub: {}", config.url);
302 let client = mz_ore::retry::Retry::default()
303 .clamp_backoff(PUBSUB_CONNECT_MAX_BACKOFF.get(&config.persist_cfg))
304 .retry_async(|_| async {
305 metrics
306 .pubsub_client
307 .grpc_connection
308 .connect_call_attempt_count
309 .inc();
310 let endpoint = match Endpoint::from_str(&config.url) {
311 Ok(endpoint) => endpoint,
312 Err(err) => return RetryResult::FatalErr(err),
313 };
314 ProtoPersistPubSubClient::connect(
315 endpoint
316 .connect_timeout(
317 PUBSUB_CONNECT_ATTEMPT_TIMEOUT.get(&config.persist_cfg),
318 )
319 .timeout(PUBSUB_REQUEST_TIMEOUT.get(&config.persist_cfg)),
320 )
321 .await
322 .into()
323 })
324 .await;
325
326 let mut client = match client {
327 Ok(client) => client,
328 Err(err) => {
329 error!("fatal error connecting to persist pubsub: {:?}", err);
330 return;
331 }
332 };
333
334 metrics
335 .pubsub_client
336 .grpc_connection
337 .connection_established_count
338 .inc();
339 metrics.pubsub_client.grpc_connection.connected.set(1);
340
341 info!("Connected to Persist PubSub: {}", config.url);
342
343 let mut broadcast = BroadcastStream::new(send_requests.subscribe());
344 let broadcast_errors = metrics
345 .pubsub_client
346 .grpc_connection
347 .broadcast_recv_lagged_count
348 .clone();
349
350 let broadcast_messages = async_stream::stream! {
354 'reconnect: loop {
355 for id in sender.subscriptions() {
357 debug!("re-subscribing to shard: {id}");
358 yield create_request(proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
359 shard_id: id.into_proto(),
360 }));
361 }
362
363 while let Some(message) = broadcast.next().await {
365 debug!("sending pubsub message: {:?}", message);
366 match message {
367 Ok(message) => yield message,
368 Err(BroadcastStreamRecvError::Lagged(i)) => {
369 broadcast_errors.inc_by(i);
370 continue 'reconnect;
371 }
372 }
373 }
374 debug!("exhausted pubsub broadcast stream; shutting down");
375 break;
376 }
377 };
378 let pubsub_request =
379 Request::from_parts(metadata.clone(), Extensions::default(), broadcast_messages);
380
381 let responses = match client.pub_sub(pubsub_request).await {
382 Ok(response) => response.into_inner(),
383 Err(err) => {
384 warn!("pub_sub rpc error: {:?}", err);
385 continue;
386 }
387 };
388
389 let stream_completed = GrpcPubSubClient::consume_grpc_stream(
390 responses,
391 receiver_input,
392 &config,
393 metrics.as_ref(),
394 )
395 .await;
396
397 match stream_completed {
398 Ok(_) => continue,
400 Err(err) => {
403 warn!("shutting down connection loop to Persist PubSub: {}", err);
404 return;
405 }
406 }
407 }
408 }
409
410 async fn consume_grpc_stream(
411 mut responses: Streaming<ProtoPubSubMessage>,
412 receiver_input: &Sender<ProtoPubSubMessage>,
413 config: &PersistPubSubClientConfig,
414 metrics: &Metrics,
415 ) -> Result<(), Error> {
416 loop {
417 if !PUBSUB_CLIENT_ENABLED.get(&config.persist_cfg) {
418 return Ok(());
419 }
420
421 debug!("awaiting next pubsub response");
422 match responses.next().await {
423 Some(Ok(message)) => {
424 debug!("received pubsub message: {:?}", message);
425 match receiver_input.send(message).await {
426 Ok(_) => {}
427 Err(err) => {
430 return Err(anyhow!("closing pubsub grpc client connection: {}", err));
431 }
432 }
433 }
434 Some(Err(err)) => {
435 metrics.pubsub_client.grpc_connection.grpc_error_count.inc();
436 warn!("pubsub client error: {:?}", err);
437 return Ok(());
438 }
439 None => return Ok(()),
440 }
441 }
442 }
443}
444
445impl PersistPubSubClient for GrpcPubSubClient {
446 fn connect(config: PersistPubSubClientConfig, metrics: Arc<Metrics>) -> PubSubClientConnection {
447 let (send_requests, _) = tokio::sync::broadcast::channel(
452 PUBSUB_CLIENT_SENDER_CHANNEL_SIZE.get(&config.persist_cfg),
453 );
454 let (receiver_input, receiver_output) = tokio::sync::mpsc::channel(
458 PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&config.persist_cfg),
459 );
460
461 let sender = Arc::new(SubscriptionTrackingSender::new(Arc::new(
462 GrpcPubSubSender {
463 metrics: Arc::clone(&metrics),
464 requests: send_requests.clone(),
465 },
466 )));
467 let pubsub_sender = Arc::clone(&sender);
468 mz_ore::task::spawn(
469 || "persist::rpc::client::connection".to_string(),
470 async move {
471 let mut metadata = MetadataMap::new();
472 metadata.insert(
473 AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY),
474 AsciiMetadataValue::try_from(&config.caller_id)
475 .unwrap_or_else(|_| AsciiMetadataValue::from_static("unknown")),
476 );
477
478 GrpcPubSubClient::reconnect_to_server_forever(
479 send_requests,
480 &receiver_input,
481 pubsub_sender,
482 metadata,
483 config,
484 metrics,
485 )
486 .await;
487 },
488 );
489
490 PubSubClientConnection {
491 sender,
492 receiver: Box::new(ReceiverStream::new(receiver_output)),
493 }
494 }
495}
496
497struct GrpcPubSubSender {
499 metrics: Arc<Metrics>,
500 requests: tokio::sync::broadcast::Sender<ProtoPubSubMessage>,
501}
502
503impl Debug for GrpcPubSubSender {
504 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
505 let GrpcPubSubSender {
506 metrics: _metrics,
507 requests: _requests,
508 } = self;
509
510 write!(f, "GrpcPubSubSender")
511 }
512}
513
514fn create_request(message: proto_pub_sub_message::Message) -> ProtoPubSubMessage {
515 let now = SystemTime::now()
516 .duration_since(SystemTime::UNIX_EPOCH)
517 .expect("failed to get millis since epoch");
518
519 ProtoPubSubMessage {
520 timestamp: Some(now.into_proto()),
521 message: Some(message),
522 }
523}
524
525impl GrpcPubSubSender {
526 fn send(&self, message: proto_pub_sub_message::Message, metrics: &PubSubClientCallMetrics) {
527 let size = message.encoded_len();
528
529 match self.requests.send(create_request(message)) {
530 Ok(_) => {
531 metrics.succeeded.inc();
532 metrics.bytes_sent.inc_by(u64::cast_from(size));
533 }
534 Err(err) => {
535 metrics.failed.inc();
536 debug!("error sending client message: {}", err);
537 }
538 }
539 }
540}
541
542impl PubSubSenderInternal for GrpcPubSubSender {
543 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
544 self.send(
545 proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
546 shard_id: shard_id.into_proto(),
547 seqno: diff.seqno.into_proto(),
548 diff: diff.data.clone(),
549 }),
550 &self.metrics.pubsub_client.sender.push,
551 )
552 }
553
554 fn subscribe(&self, shard_id: &ShardId) {
555 self.send(
556 proto_pub_sub_message::Message::Subscribe(ProtoSubscribe {
557 shard_id: shard_id.into_proto(),
558 }),
559 &self.metrics.pubsub_client.sender.subscribe,
560 )
561 }
562
563 fn unsubscribe(&self, shard_id: &ShardId) {
564 self.send(
565 proto_pub_sub_message::Message::Unsubscribe(ProtoUnsubscribe {
566 shard_id: shard_id.into_proto(),
567 }),
568 &self.metrics.pubsub_client.sender.unsubscribe,
569 )
570 }
571}
572
573#[derive(Debug)]
576struct SubscriptionTrackingSender {
577 delegate: Arc<dyn PubSubSenderInternal>,
578 subscribes: Arc<Mutex<BTreeMap<ShardId, Weak<ShardSubscriptionToken>>>>,
579}
580
581impl SubscriptionTrackingSender {
582 fn new(sender: Arc<dyn PubSubSenderInternal>) -> Self {
583 Self {
584 delegate: sender,
585 subscribes: Default::default(),
586 }
587 }
588
589 fn subscriptions(&self) -> Vec<ShardId> {
590 let mut subscribes = self.subscribes.lock().expect("lock");
591 let mut out = Vec::with_capacity(subscribes.len());
592 subscribes.retain(|shard_id, token| {
593 if token.upgrade().is_none() {
594 false
595 } else {
596 debug!("reconnecting to: {}", shard_id);
597 out.push(*shard_id);
598 true
599 }
600 });
601 out
602 }
603}
604
605impl PubSubSender for SubscriptionTrackingSender {
606 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
607 self.delegate.push_diff(shard_id, diff)
608 }
609
610 fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
611 let mut subscribes = self.subscribes.lock().expect("lock");
612 if let Some(token) = subscribes.get(shard_id) {
613 match token.upgrade() {
614 None => assert!(subscribes.remove(shard_id).is_some()),
615 Some(token) => {
616 return Arc::clone(&token);
617 }
618 }
619 }
620
621 let pubsub_sender = Arc::clone(&self.delegate);
622 let token = Arc::new(ShardSubscriptionToken {
623 shard_id: *shard_id,
624 sender: pubsub_sender,
625 });
626
627 assert!(
628 subscribes
629 .insert(*shard_id, Arc::downgrade(&token))
630 .is_none()
631 );
632
633 self.delegate.subscribe(shard_id);
634
635 token
636 }
637}
638
639#[derive(Debug)]
643pub struct MetricsSameProcessPubSubSender {
644 delegate_subscribe: bool,
645 metrics: Arc<Metrics>,
646 delegate: Arc<dyn PubSubSender>,
647}
648
649impl MetricsSameProcessPubSubSender {
650 pub fn new(
653 cfg: &PersistConfig,
654 pubsub_sender: Arc<dyn PubSubSender>,
655 metrics: Arc<Metrics>,
656 ) -> Self {
657 Self {
658 delegate_subscribe: PUBSUB_SAME_PROCESS_DELEGATE_ENABLED.get(cfg),
659 delegate: pubsub_sender,
660 metrics,
661 }
662 }
663}
664
665impl PubSubSender for MetricsSameProcessPubSubSender {
666 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
667 self.delegate.push_diff(shard_id, diff);
668 self.metrics.pubsub_client.sender.push.succeeded.inc();
669 }
670
671 fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
672 if self.delegate_subscribe {
673 let delegate = Arc::clone(&self.delegate);
674 delegate.subscribe(shard_id)
675 } else {
676 Arc::new(ShardSubscriptionToken {
682 shard_id: *shard_id,
683 sender: Arc::new(NoopPubSubSender),
684 })
685 }
686 }
687}
688
689#[derive(Debug)]
690pub(crate) struct NoopPubSubSender;
691
692impl PubSubSenderInternal for NoopPubSubSender {
693 fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
694 fn subscribe(&self, _shard_id: &ShardId) {}
695 fn unsubscribe(&self, _shard_id: &ShardId) {}
696}
697
698impl PubSubSender for NoopPubSubSender {
699 fn push_diff(&self, _shard_id: &ShardId, _diff: &VersionedData) {}
700
701 fn subscribe(self: Arc<Self>, shard_id: &ShardId) -> Arc<ShardSubscriptionToken> {
702 Arc::new(ShardSubscriptionToken {
703 shard_id: *shard_id,
704 sender: self,
705 })
706 }
707}
708
709pub(crate) fn subscribe_state_cache_to_pubsub(
711 cache: Arc<StateCache>,
712 mut pubsub_receiver: Box<dyn PubSubReceiver>,
713) -> JoinHandle<()> {
714 let mut state_refs: HashMap<ShardId, Weak<dyn DynState>> = HashMap::new();
715 let receiver_metrics = cache.metrics.pubsub_client.receiver.clone();
716
717 mz_ore::task::spawn(
718 || "persist::rpc::client::state_cache_diff_apply",
719 async move {
720 while let Some(msg) = pubsub_receiver.next().await {
721 match msg.message {
722 Some(proto_pub_sub_message::Message::PushDiff(diff)) => {
723 receiver_metrics.push_received.inc();
724 let shard_id = diff.shard_id.into_rust().expect("valid shard id");
725 let diff = VersionedData {
726 seqno: diff.seqno.into_rust().expect("valid SeqNo"),
727 data: diff.diff,
728 };
729 debug!(
730 "applying pubsub diff {} {} {}",
731 shard_id,
732 diff.seqno,
733 diff.data.len()
734 );
735
736 let mut pushed_diff = false;
737 if let Some(state_ref) = state_refs.get(&shard_id) {
738 if let Some(state) = state_ref.upgrade() {
741 state.push_diff(diff.clone());
742 pushed_diff = true;
743 receiver_metrics.state_pushed_diff_fast_path.inc();
744 }
745 }
746
747 if !pushed_diff {
748 let state_ref = cache.get_state_weak(&shard_id);
753 match state_ref {
754 None => {
755 state_refs.remove(&shard_id);
756 }
757 Some(state_ref) => {
758 if let Some(state) = state_ref.upgrade() {
759 state.push_diff(diff);
760 pushed_diff = true;
761 state_refs.insert(shard_id, state_ref);
762 } else {
763 state_refs.remove(&shard_id);
764 }
765 }
766 }
767
768 if pushed_diff {
769 receiver_metrics.state_pushed_diff_slow_path_succeeded.inc();
770 } else {
771 receiver_metrics.state_pushed_diff_slow_path_failed.inc();
772 }
773 }
774
775 if let Some(send_timestamp) = msg.timestamp {
776 let send_timestamp =
777 send_timestamp.into_rust().expect("valid timestamp");
778 let now = SystemTime::now()
779 .duration_since(SystemTime::UNIX_EPOCH)
780 .expect("failed to get millis since epoch");
781 receiver_metrics
782 .approx_diff_latency_seconds
783 .observe((now.saturating_sub(send_timestamp)).as_secs_f64());
784 }
785 }
786 ref msg @ None | ref msg @ Some(_) => {
787 warn!("pubsub client received unexpected message: {:?}", msg);
788 receiver_metrics.unknown_message_received.inc();
789 }
790 }
791 }
792 },
793 )
794}
795
796#[derive(Debug)]
798pub(crate) struct PubSubState {
799 connection_id_counter: AtomicUsize,
801 shard_subscribers:
803 Arc<RwLock<BTreeMap<ShardId, BTreeMap<usize, Sender<Result<ProtoPubSubMessage, Status>>>>>>,
804 connections: Arc<RwLock<HashSet<usize>>>,
806 metrics: Arc<PubSubServerMetrics>,
808}
809
810impl PubSubState {
811 fn new_connection(
812 self: Arc<Self>,
813 notifier: Sender<Result<ProtoPubSubMessage, Status>>,
814 ) -> PubSubConnection {
815 let connection_id = self.connection_id_counter.fetch_add(1, Ordering::SeqCst);
816 {
817 debug!("inserting connid: {}", connection_id);
818 let mut connections = self.connections.write().expect("lock");
819 assert!(connections.insert(connection_id));
820 }
821
822 self.metrics.active_connections.inc();
823 PubSubConnection {
824 connection_id,
825 notifier,
826 state: self,
827 }
828 }
829
830 fn remove_connection(&self, connection_id: usize) {
831 let now = Instant::now();
832
833 {
834 debug!("removing connid: {}", connection_id);
835 let mut connections = self.connections.write().expect("lock");
836 assert!(
837 connections.remove(&connection_id),
838 "unknown connection id: {}",
839 connection_id
840 );
841 }
842
843 {
844 let mut subscribers = self.shard_subscribers.write().expect("lock poisoned");
845 subscribers.retain(|_shard, connections_for_shard| {
846 connections_for_shard.remove(&connection_id);
847 !connections_for_shard.is_empty()
848 });
849 }
850
851 self.metrics
852 .connection_cleanup_seconds
853 .inc_by(now.elapsed().as_secs_f64());
854 self.metrics.active_connections.dec();
855 }
856
857 fn push_diff(&self, connection_id: usize, shard_id: &ShardId, data: &VersionedData) {
858 let now = Instant::now();
859 self.metrics.push_call_count.inc();
860
861 assert!(
862 self.connections
863 .read()
864 .expect("lock")
865 .contains(&connection_id),
866 "unknown connection id: {}",
867 connection_id
868 );
869
870 let subscribers = self.shard_subscribers.read().expect("lock poisoned");
871 if let Some(subscribed_connections) = subscribers.get(shard_id) {
872 let mut num_sent = 0;
873 let mut data_size = 0;
874
875 for (subscribed_conn_id, tx) in subscribed_connections {
876 if *subscribed_conn_id == connection_id {
878 continue;
879 }
880 debug!(
881 "server forwarding req to conn {}: {} {} {}",
882 subscribed_conn_id,
883 &shard_id,
884 data.seqno,
885 data.data.len()
886 );
887 let req = create_request(proto_pub_sub_message::Message::PushDiff(ProtoPushDiff {
888 seqno: data.seqno.into_proto(),
889 shard_id: shard_id.to_string(),
890 diff: Bytes::clone(&data.data),
891 }));
892 data_size = req.encoded_len();
893 match tx.try_send(Ok(req)) {
894 Ok(_) => {
895 num_sent += 1;
896 }
897 Err(TrySendError::Full(_)) => {
898 self.metrics.broadcasted_diff_dropped_channel_full.inc();
899 }
900 Err(TrySendError::Closed(_)) => {}
901 };
902 }
903
904 self.metrics.broadcasted_diff_count.inc_by(num_sent);
905 self.metrics
906 .broadcasted_diff_bytes
907 .inc_by(num_sent * u64::cast_from(data_size));
908 }
909
910 self.metrics
911 .push_seconds
912 .inc_by(now.elapsed().as_secs_f64());
913 }
914
915 fn subscribe(
916 &self,
917 connection_id: usize,
918 notifier: Sender<Result<ProtoPubSubMessage, Status>>,
919 shard_id: &ShardId,
920 ) {
921 let now = Instant::now();
922 self.metrics.subscribe_call_count.inc();
923
924 assert!(
925 self.connections
926 .read()
927 .expect("lock")
928 .contains(&connection_id),
929 "unknown connection id: {}",
930 connection_id
931 );
932
933 {
934 let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
935 subscribed_shards
936 .entry(*shard_id)
937 .or_default()
938 .insert(connection_id, notifier);
939 }
940
941 self.metrics
942 .subscribe_seconds
943 .inc_by(now.elapsed().as_secs_f64());
944 }
945
946 fn unsubscribe(&self, connection_id: usize, shard_id: &ShardId) {
947 let now = Instant::now();
948 self.metrics.unsubscribe_call_count.inc();
949
950 assert!(
951 self.connections
952 .read()
953 .expect("lock")
954 .contains(&connection_id),
955 "unknown connection id: {}",
956 connection_id
957 );
958
959 {
960 let mut subscribed_shards = self.shard_subscribers.write().expect("lock poisoned");
961 if let Entry::Occupied(mut entry) = subscribed_shards.entry(*shard_id) {
962 let subscribed_connections = entry.get_mut();
963 subscribed_connections.remove(&connection_id);
964
965 if subscribed_connections.is_empty() {
966 entry.remove_entry();
967 }
968 }
969 }
970
971 self.metrics
972 .unsubscribe_seconds
973 .inc_by(now.elapsed().as_secs_f64());
974 }
975
976 #[cfg(test)]
977 fn new_for_test() -> Self {
978 Self {
979 connection_id_counter: AtomicUsize::new(0),
980 shard_subscribers: Default::default(),
981 connections: Default::default(),
982 metrics: Arc::new(PubSubServerMetrics::new(&MetricsRegistry::new())),
983 }
984 }
985
986 #[cfg(test)]
987 fn active_connections(&self) -> HashSet<usize> {
988 self.connections.read().expect("lock").clone()
989 }
990
991 #[cfg(test)]
992 fn subscriptions(&self, connection_id: usize) -> HashSet<ShardId> {
993 let mut shards = HashSet::new();
994
995 let subscribers = self.shard_subscribers.read().expect("lock");
996 for (shard, subscribed_connections) in subscribers.iter() {
997 if subscribed_connections.contains_key(&connection_id) {
998 shards.insert(*shard);
999 }
1000 }
1001
1002 shards
1003 }
1004
1005 #[cfg(test)]
1006 fn shard_subscription_counts(&self) -> mz_ore::collections::HashMap<ShardId, usize> {
1007 let mut shards = mz_ore::collections::HashMap::new();
1008
1009 let subscribers = self.shard_subscribers.read().expect("lock");
1010 for (shard, subscribed_connections) in subscribers.iter() {
1011 shards.insert(*shard, subscribed_connections.len());
1012 }
1013
1014 shards
1015 }
1016}
1017
1018#[derive(Debug)]
1020pub struct PersistGrpcPubSubServer {
1021 cfg: PersistConfig,
1022 state: Arc<PubSubState>,
1023}
1024
1025impl PersistGrpcPubSubServer {
1026 pub fn new(cfg: &PersistConfig, metrics_registry: &MetricsRegistry) -> Self {
1028 let metrics = PubSubServerMetrics::new(metrics_registry);
1029 let state = Arc::new(PubSubState {
1030 connection_id_counter: AtomicUsize::new(0),
1031 shard_subscribers: Default::default(),
1032 connections: Default::default(),
1033 metrics: Arc::new(metrics),
1034 });
1035
1036 PersistGrpcPubSubServer {
1037 cfg: cfg.clone(),
1038 state,
1039 }
1040 }
1041
1042 pub fn new_same_process_connection(&self) -> PubSubClientConnection {
1046 let (tx, rx) =
1047 tokio::sync::mpsc::channel(PUBSUB_CLIENT_RECEIVER_CHANNEL_SIZE.get(&self.cfg));
1048 let sender: Arc<dyn PubSubSender> = Arc::new(SubscriptionTrackingSender::new(Arc::new(
1049 Arc::clone(&self.state).new_connection(tx),
1050 )));
1051
1052 PubSubClientConnection {
1053 sender,
1054 receiver: Box::new(
1055 ReceiverStream::new(rx).map(|x| x.expect("cannot receive grpc errors locally")),
1056 ),
1057 }
1058 }
1059
1060 pub async fn serve(self, listen_addr: SocketAddr) -> Result<(), anyhow::Error> {
1062 tonic::transport::Server::builder()
1064 .add_service(ProtoPersistPubSubServer::new(self).max_decoding_message_size(usize::MAX))
1065 .serve(listen_addr)
1066 .await?;
1067 Ok(())
1068 }
1069
1070 pub async fn serve_with_stream(
1073 self,
1074 listener: tokio_stream::wrappers::TcpListenerStream,
1075 ) -> Result<(), anyhow::Error> {
1076 tonic::transport::Server::builder()
1077 .add_service(ProtoPersistPubSubServer::new(self))
1078 .serve_with_incoming(listener)
1079 .await?;
1080 Ok(())
1081 }
1082}
1083
1084#[async_trait]
1085impl proto_persist_pub_sub_server::ProtoPersistPubSub for PersistGrpcPubSubServer {
1086 type PubSubStream = Pin<Box<dyn Stream<Item = Result<ProtoPubSubMessage, Status>> + Send>>;
1087
1088 #[mz_ore::instrument(name = "persist::rpc::server", level = "info")]
1089 async fn pub_sub(
1090 &self,
1091 request: Request<Streaming<ProtoPubSubMessage>>,
1092 ) -> Result<Response<Self::PubSubStream>, Status> {
1093 let caller_id = request
1094 .metadata()
1095 .get(AsciiMetadataKey::from_static(PERSIST_PUBSUB_CALLER_KEY))
1096 .map(|key| key.to_str().ok())
1097 .flatten()
1098 .map(|key| key.to_string())
1099 .unwrap_or_else(|| "unknown".to_string());
1100 info!("Received Persist PubSub connection from: {:?}", caller_id);
1101
1102 let mut in_stream = request.into_inner();
1103 let (tx, rx) =
1104 tokio::sync::mpsc::channel(PUBSUB_SERVER_CONNECTION_CHANNEL_SIZE.get(&self.cfg));
1105
1106 let caller = caller_id.clone();
1107 let cfg = Arc::clone(&self.cfg.configs);
1108 let server_state = Arc::clone(&self.state);
1109 let connection_span = info_span!("connection", caller_id);
1113 mz_ore::task::spawn(
1114 || format!("persist_pubsub_connection({})", caller),
1115 async move {
1116 let connection = server_state.new_connection(tx);
1117 while let Some(result) = in_stream.next().await {
1118 let req = match result {
1119 Ok(req) => req,
1120 Err(err) => {
1121 warn!("pubsub connection err: {}", err);
1122 break;
1123 }
1124 };
1125
1126 match req.message {
1127 None => {
1128 warn!("received empty message from: {}", caller_id);
1129 }
1130 Some(proto_pub_sub_message::Message::PushDiff(req)) => {
1131 let shard_id = req.shard_id.parse().expect("valid shard id");
1132 let diff = VersionedData {
1133 seqno: req.seqno.into_rust().expect("valid seqno"),
1134 data: req.diff.clone(),
1135 };
1136 if PUBSUB_PUSH_DIFF_ENABLED.get(&cfg) {
1137 connection.push_diff(&shard_id, &diff);
1138 }
1139 }
1140 Some(proto_pub_sub_message::Message::Subscribe(diff)) => {
1141 let shard_id = diff.shard_id.parse().expect("valid shard id");
1142 connection.subscribe(&shard_id);
1143 }
1144 Some(proto_pub_sub_message::Message::Unsubscribe(diff)) => {
1145 let shard_id = diff.shard_id.parse().expect("valid shard id");
1146 connection.unsubscribe(&shard_id);
1147 }
1148 }
1149 }
1150
1151 info!("Persist PubSub connection ended: {:?}", caller_id);
1152 }
1153 .instrument(connection_span),
1154 );
1155
1156 let out_stream: Self::PubSubStream = Box::pin(ReceiverStream::new(rx));
1157 Ok(Response::new(out_stream))
1158 }
1159}
1160
1161#[derive(Debug)]
1165pub(crate) struct PubSubConnection {
1166 connection_id: usize,
1167 notifier: Sender<Result<ProtoPubSubMessage, Status>>,
1168 state: Arc<PubSubState>,
1169}
1170
1171impl PubSubSenderInternal for PubSubConnection {
1172 fn push_diff(&self, shard_id: &ShardId, diff: &VersionedData) {
1173 self.state.push_diff(self.connection_id, shard_id, diff)
1174 }
1175
1176 fn subscribe(&self, shard_id: &ShardId) {
1177 self.state
1178 .subscribe(self.connection_id, self.notifier.clone(), shard_id)
1179 }
1180
1181 fn unsubscribe(&self, shard_id: &ShardId) {
1182 self.state.unsubscribe(self.connection_id, shard_id)
1183 }
1184}
1185
1186impl Drop for PubSubConnection {
1187 fn drop(&mut self) {
1188 self.state.remove_connection(self.connection_id)
1189 }
1190}
1191
1192#[cfg(test)]
1193mod pubsub_state {
1194 use std::str::FromStr;
1195 use std::sync::Arc;
1196 use std::sync::LazyLock;
1197
1198 use bytes::Bytes;
1199 use mz_ore::collections::HashSet;
1200 use mz_persist::location::{SeqNo, VersionedData};
1201 use mz_proto::RustType;
1202 use tokio::sync::mpsc::Receiver;
1203 use tokio::sync::mpsc::error::TryRecvError;
1204 use tonic::Status;
1205
1206 use crate::ShardId;
1207 use crate::internal::service::ProtoPubSubMessage;
1208 use crate::internal::service::proto_pub_sub_message::Message;
1209 use crate::rpc::{PubSubSenderInternal, PubSubState};
1210
1211 static SHARD_ID_0: LazyLock<ShardId> =
1212 LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
1213 static SHARD_ID_1: LazyLock<ShardId> =
1214 LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
1215
1216 const VERSIONED_DATA_0: VersionedData = VersionedData {
1217 seqno: SeqNo(0),
1218 data: Bytes::from_static(&[0, 1, 2, 3]),
1219 };
1220
1221 const VERSIONED_DATA_1: VersionedData = VersionedData {
1222 seqno: SeqNo(1),
1223 data: Bytes::from_static(&[4, 5, 6, 7]),
1224 };
1225
1226 #[mz_ore::test]
1227 #[should_panic(expected = "unknown connection id: 100")]
1228 fn test_zero_connections_push_diff() {
1229 let state = Arc::new(PubSubState::new_for_test());
1230 state.push_diff(100, &SHARD_ID_0, &VERSIONED_DATA_0);
1231 }
1232
1233 #[mz_ore::test]
1234 #[should_panic(expected = "unknown connection id: 100")]
1235 fn test_zero_connections_subscribe() {
1236 let state = Arc::new(PubSubState::new_for_test());
1237 let (tx, _) = tokio::sync::mpsc::channel(100);
1238 state.subscribe(100, tx, &SHARD_ID_0);
1239 }
1240
1241 #[mz_ore::test]
1242 #[should_panic(expected = "unknown connection id: 100")]
1243 fn test_zero_connections_unsubscribe() {
1244 let state = Arc::new(PubSubState::new_for_test());
1245 state.unsubscribe(100, &SHARD_ID_0);
1246 }
1247
1248 #[mz_ore::test]
1249 #[should_panic(expected = "unknown connection id: 100")]
1250 fn test_zero_connections_remove() {
1251 let state = Arc::new(PubSubState::new_for_test());
1252 state.remove_connection(100)
1253 }
1254
1255 #[mz_ore::test]
1256 fn test_single_connection() {
1257 let state = Arc::new(PubSubState::new_for_test());
1258
1259 let (tx, mut rx) = tokio::sync::mpsc::channel(100);
1260 let connection = Arc::clone(&state).new_connection(tx);
1261
1262 assert_eq!(
1263 state.active_connections(),
1264 HashSet::from([connection.connection_id])
1265 );
1266
1267 assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
1269
1270 connection.push_diff(
1271 &SHARD_ID_0,
1272 &VersionedData {
1273 seqno: SeqNo::minimum(),
1274 data: Bytes::new(),
1275 },
1276 );
1277
1278 assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
1280
1281 connection.subscribe(&SHARD_ID_0);
1283 assert_eq!(
1284 state.subscriptions(connection.connection_id),
1285 HashSet::from([SHARD_ID_0.clone()])
1286 );
1287
1288 connection.unsubscribe(&SHARD_ID_0);
1290 assert!(state.subscriptions(connection.connection_id).is_empty());
1291
1292 connection.subscribe(&SHARD_ID_0);
1294 connection.subscribe(&SHARD_ID_1);
1295 assert_eq!(
1296 state.subscriptions(connection.connection_id),
1297 HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
1298 );
1299
1300 connection.subscribe(&SHARD_ID_0);
1302 connection.subscribe(&SHARD_ID_0);
1303 assert_eq!(
1304 state.subscriptions(connection.connection_id),
1305 HashSet::from([*SHARD_ID_0, *SHARD_ID_1])
1306 );
1307
1308 let connection_id = connection.connection_id;
1310 drop(connection);
1311 assert!(state.subscriptions(connection_id).is_empty());
1312 assert!(state.active_connections().is_empty());
1313 }
1314
1315 #[mz_ore::test]
1316 fn test_many_connection() {
1317 let state = Arc::new(PubSubState::new_for_test());
1318
1319 let (tx1, mut rx1) = tokio::sync::mpsc::channel(100);
1320 let conn1 = Arc::clone(&state).new_connection(tx1);
1321
1322 let (tx2, mut rx2) = tokio::sync::mpsc::channel(100);
1323 let conn2 = Arc::clone(&state).new_connection(tx2);
1324
1325 let (tx3, mut rx3) = tokio::sync::mpsc::channel(100);
1326 let conn3 = Arc::clone(&state).new_connection(tx3);
1327
1328 conn1.subscribe(&SHARD_ID_0);
1329 conn2.subscribe(&SHARD_ID_0);
1330 conn2.subscribe(&SHARD_ID_1);
1331
1332 assert_eq!(
1333 state.active_connections(),
1334 HashSet::from([
1335 conn1.connection_id,
1336 conn2.connection_id,
1337 conn3.connection_id
1338 ])
1339 );
1340
1341 conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1343 assert_push(&mut rx1, &SHARD_ID_0, &VERSIONED_DATA_0);
1344 assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1345 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1346
1347 conn1.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1349 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1350 assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1351 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1352
1353 conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
1355 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1356 assert_push(&mut rx2, &SHARD_ID_1, &VERSIONED_DATA_1);
1357 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1358
1359 conn2.unsubscribe(&SHARD_ID_1);
1361 conn3.push_diff(&SHARD_ID_1, &VERSIONED_DATA_1);
1362 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Empty)));
1363 assert!(matches!(rx2.try_recv(), Err(TryRecvError::Empty)));
1364 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1365
1366 let conn1_id = conn1.connection_id;
1368 drop(conn1);
1369 conn3.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
1370 assert!(matches!(rx1.try_recv(), Err(TryRecvError::Disconnected)));
1371 assert_push(&mut rx2, &SHARD_ID_0, &VERSIONED_DATA_0);
1372 assert!(matches!(rx3.try_recv(), Err(TryRecvError::Empty)));
1373
1374 assert!(state.subscriptions(conn1_id).is_empty());
1375 assert_eq!(
1376 state.subscriptions(conn2.connection_id),
1377 HashSet::from([*SHARD_ID_0])
1378 );
1379 assert_eq!(state.subscriptions(conn3.connection_id), HashSet::new());
1380 assert_eq!(
1381 state.active_connections(),
1382 HashSet::from([conn2.connection_id, conn3.connection_id])
1383 );
1384 }
1385
1386 fn assert_push(
1387 rx: &mut Receiver<Result<ProtoPubSubMessage, Status>>,
1388 shard: &ShardId,
1389 data: &VersionedData,
1390 ) {
1391 let message = rx
1392 .try_recv()
1393 .expect("message in channel")
1394 .expect("pubsub")
1395 .message
1396 .expect("proto contains message");
1397 match message {
1398 Message::PushDiff(x) => {
1399 assert_eq!(x.shard_id, shard.into_proto());
1400 assert_eq!(x.seqno, data.seqno.into_proto());
1401 assert_eq!(x.diff, data.data);
1402 }
1403 Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
1404 };
1405 }
1406}
1407
1408#[cfg(test)]
1409mod grpc {
1410 use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
1411 use std::str::FromStr;
1412 use std::sync::Arc;
1413 use std::time::{Duration, Instant};
1414
1415 use bytes::Bytes;
1416 use futures_util::FutureExt;
1417 use mz_dyncfg::ConfigUpdates;
1418 use mz_ore::assert_none;
1419 use mz_ore::collections::HashMap;
1420 use mz_ore::metrics::MetricsRegistry;
1421 use mz_persist::location::{SeqNo, VersionedData};
1422 use mz_proto::RustType;
1423 use std::sync::LazyLock;
1424 use tokio::net::TcpListener;
1425 use tokio_stream::StreamExt;
1426 use tokio_stream::wrappers::TcpListenerStream;
1427
1428 use crate::ShardId;
1429 use crate::cfg::PersistConfig;
1430 use crate::internal::service::ProtoPubSubMessage;
1431 use crate::internal::service::proto_pub_sub_message::Message;
1432 use crate::metrics::Metrics;
1433 use crate::rpc::{
1434 GrpcPubSubClient, PUBSUB_CLIENT_ENABLED, PUBSUB_RECONNECT_BACKOFF, PersistGrpcPubSubServer,
1435 PersistPubSubClient, PersistPubSubClientConfig, PubSubState,
1436 };
1437
/// First fixed shard id shared by the tests in this module; parsed on first access.
static SHARD_ID_0: LazyLock<ShardId> =
    LazyLock::new(|| ShardId::from_str("s00000000-0000-0000-0000-000000000000").unwrap());
/// Second fixed shard id, distinct from `SHARD_ID_0`; parsed on first access.
static SHARD_ID_1: LazyLock<ShardId> =
    LazyLock::new(|| ShardId::from_str("s11111111-1111-1111-1111-111111111111").unwrap());
/// Sample diff payload at seqno 0.
const VERSIONED_DATA_0: VersionedData = VersionedData {
    seqno: SeqNo(0),
    data: Bytes::from_static(&[0, 1, 2, 3]),
};
/// Sample diff payload at seqno 1, with bytes distinct from `VERSIONED_DATA_0`.
const VERSIONED_DATA_1: VersionedData = VersionedData {
    seqno: SeqNo(1),
    data: Bytes::from_static(&[4, 5, 6, 7]),
};

/// How long tests poll for the server's connection count to reach the expected value.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// How long tests poll for the server's per-shard subscription counts to reach the expected
/// value.
const SUBSCRIPTIONS_TIMEOUT: Duration = Duration::from_secs(3);
/// Grace period passed to `Runtime::shutdown_timeout` when a test tears a runtime down (the
/// tests use it for the client runtime as well as the server runtime).
const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);
1454
/// Checks that the server tracks a client's connection and subscription, and forgets both
/// once the client's runtime has been shut down.
#[mz_ore::test]
#[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped under Miri because the test opens real sockets — confirm
fn grpc_server() {
    let metrics = Arc::new(Metrics::new(
        &test_persist_config(),
        &MetricsRegistry::new(),
    ));
    // Server and client each get their own runtime so the client side can be shut down
    // independently of the server.
    let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
    let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");

    // Bind the listener and start serving, both on the server runtime.
    let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());
    let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

    {
        // Enter the client runtime for the duration of this scope before spawning the client
        // task.
        let _guard = client_runtime.enter();
        mz_ore::task::spawn(|| "client".to_string(), async move {
            let client = GrpcPubSubClient::connect(
                PersistPubSubClientConfig {
                    url: format!("http://{}", addr),
                    caller_id: "client".to_string(),
                    persist_cfg: test_persist_config(),
                },
                metrics,
            );
            let _token = client.sender.subscribe(&SHARD_ID_0);
            // Park the task so `client` and `_token` stay alive until the runtime goes away.
            tokio::time::sleep(Duration::MAX).await;
        });
    }

    server_runtime.block_on(async {
        // The server should come to see exactly one connection, subscribed to SHARD_ID_0.
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await
    });

    // Tear down the client side.
    client_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

    server_runtime.block_on(async {
        // The server should then drop both the connection and its subscription.
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().is_empty()
        })
        .await;
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::new()
        })
        .await
    });
}
1517
/// Checks that the client's sender side carries its live subscriptions across connects:
/// subscriptions made before the server is serving, and while it is down, show up on the
/// server once a connection is established, while subscriptions whose token was already
/// dropped do not.
#[mz_ore::test]
#[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped under Miri because the test opens real sockets — confirm
fn grpc_client_sender_reconnects() {
    let metrics = Arc::new(Metrics::new(
        &test_persist_config(),
        &MetricsRegistry::new(),
    ));
    let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
    let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
    // The listener is bound here, but nothing serves on it until `spawn_server` below.
    let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

    // Create the client inside the client runtime, before the server is serving.
    let client = client_runtime.block_on(async {
        GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client".to_string(),
                persist_cfg: test_persist_config(),
            },
            metrics,
        )
    });

    // Subscribe to two shards, then immediately give up the second subscription.
    let _token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
    let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
    drop(_token_2);

    // Only now start serving.
    let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

    server_runtime.block_on(async {
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;

        // Only the subscription whose token is still alive should be visible.
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
        })
        .await;
    });

    // Take the server down.
    server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

    // Subscribe to the second shard again while no server is running.
    let _token_2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);

    // Bring up a fresh server on the same address the client was configured with.
    let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
    let tcp_listener_stream = server_runtime.block_on(async {
        TcpListenerStream::new(
            TcpListener::bind(addr)
                .await
                .expect("can bind to previous addr"),
        )
    });
    let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

    server_runtime.block_on(async {
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 1
        })
        .await;

        // Both live subscriptions should be visible on the new server.
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts()
                == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
        })
        .await;
    });
}
1598
/// Checks how subscription tokens map to server-side subscriptions: holding a token keeps
/// the shard subscribed, dropping the last token unsubscribes it, and several tokens for the
/// same shard count as a single subscription.
#[mz_ore::test(tokio::test(flavor = "multi_thread"))]
#[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped under Miri because the test opens real sockets — confirm
async fn grpc_client_sender_subscription_tokens() {
    let metrics = Arc::new(Metrics::new(
        &test_persist_config(),
        &MetricsRegistry::new(),
    ));

    let (addr, tcp_listener_stream) = new_tcp_listener().await;
    let server_state = spawn_server(tcp_listener_stream).await;

    let client = GrpcPubSubClient::connect(
        PersistPubSubClientConfig {
            url: format!("http://{}", addr),
            caller_id: "client".to_string(),
            persist_cfg: test_persist_config(),
        },
        metrics,
    );

    // Wait until the server has registered the client's connection.
    poll_until_true(CONNECT_TIMEOUT, || {
        server_state.active_connections().len() == 1
    })
    .await;

    // Holding a token: the server sees one subscription for the shard.
    let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
    poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
    })
    .await;

    // Dropping the token: the subscription goes away.
    drop(token);
    poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts() == HashMap::new()
    })
    .await;

    // Subscribing again after a drop works as before.
    let token = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
    poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
    })
    .await;

    // Further subscribes to the same shard hand back handles to the same shared token (three
    // strong references in total), and the server still counts one subscription.
    let token2 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
    let token3 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
    assert_eq!(Arc::strong_count(&token), 3);
    poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 1)])
    })
    .await;

    // Once every handle is gone, the subscription is removed.
    drop(token);
    drop(token2);
    drop(token3);
    poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts() == HashMap::new()
    })
    .await;

    // Tokens for different shards are tracked independently.
    let _token0 = Arc::clone(&client.sender).subscribe(&SHARD_ID_0);
    let _token1 = Arc::clone(&client.sender).subscribe(&SHARD_ID_1);
    poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts()
            == HashMap::from([(*SHARD_ID_0, 1), (*SHARD_ID_1, 1)])
    })
    .await;
}
1673
/// Checks the receiving side of two clients: a diff pushed by one client is delivered to the
/// other subscribed client but not back to the pusher, the receivers yield nothing while the
/// server is absent, and delivery works again after the server is restarted on the same
/// address.
#[mz_ore::test]
#[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped under Miri because the test opens real sockets — confirm
fn grpc_client_receiver() {
    let metrics = Arc::new(Metrics::new(
        &PersistConfig::new_for_tests(),
        &MetricsRegistry::new(),
    ));
    let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
    let client_runtime = tokio::runtime::Runtime::new().expect("client runtime");
    // The listener is bound here, but nothing serves on it until `spawn_server` below.
    let (addr, tcp_listener_stream) = server_runtime.block_on(new_tcp_listener());

    // Create both clients inside the client runtime, before the server is serving.
    let mut client_1 = client_runtime.block_on(async {
        GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client_1".to_string(),
                persist_cfg: test_persist_config(),
            },
            Arc::clone(&metrics),
        )
    });
    let mut client_2 = client_runtime.block_on(async {
        GrpcPubSubClient::connect(
            PersistPubSubClientConfig {
                url: format!("http://{}", addr),
                caller_id: "client_2".to_string(),
                persist_cfg: test_persist_config(),
            },
            metrics,
        )
    });

    // With no server serving yet, neither receiver has an item ready (nor has it ended).
    assert_none!(client_1.receiver.next().now_or_never());
    assert_none!(client_2.receiver.next().now_or_never());

    // Start serving and wait for both clients to connect.
    let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

    server_runtime.block_on(poll_until_true(CONNECT_TIMEOUT, || {
        server_state.active_connections().len() == 2
    }));

    // Connecting alone produces no messages.
    assert_none!(client_1.receiver.next().now_or_never());
    assert_none!(client_2.receiver.next().now_or_never());

    // Subscribe both clients to the same shard.
    let _token_client_1 = Arc::clone(&client_1.sender).subscribe(&SHARD_ID_0);
    let _token_client_2 = Arc::clone(&client_2.sender).subscribe(&SHARD_ID_0);
    server_runtime.block_on(poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
        server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
    }));

    // A push from client_1 is not visible on its own receiver but reaches client_2.
    client_1.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_1);
    assert_none!(client_1.receiver.next().now_or_never());
    client_runtime.block_on(async {
        assert_push(
            client_2.receiver.next().await.expect("has diff"),
            &SHARD_ID_0,
            &VERSIONED_DATA_1,
        )
    });

    // Take the server down.
    server_runtime.shutdown_timeout(SERVER_SHUTDOWN_TIMEOUT);

    // Losing the server surfaces nothing on the receivers: no item and no end of stream.
    assert_none!(client_1.receiver.next().now_or_never());
    assert_none!(client_2.receiver.next().now_or_never());

    // Bring up a fresh server on the same address the clients were configured with.
    let server_runtime = tokio::runtime::Runtime::new().expect("server runtime");
    let tcp_listener_stream = server_runtime.block_on(async {
        TcpListenerStream::new(
            TcpListener::bind(addr)
                .await
                .expect("can bind to previous addr"),
        )
    });
    let server_state = server_runtime.block_on(spawn_server(tcp_listener_stream));

    server_runtime.block_on(async {
        // Both clients should show up on the new server with their subscriptions in place.
        poll_until_true(CONNECT_TIMEOUT, || {
            server_state.active_connections().len() == 2
        })
        .await;
        poll_until_true(SUBSCRIPTIONS_TIMEOUT, || {
            server_state.shard_subscription_counts() == HashMap::from([(*SHARD_ID_0, 2)])
        })
        .await;
    });

    // Delivery works in the other direction too: client_2 pushes, client_1 receives, and
    // client_2 sees nothing on its own receiver.
    client_2.sender.push_diff(&SHARD_ID_0, &VERSIONED_DATA_0);
    client_runtime.block_on(async {
        assert_push(
            client_1.receiver.next().await.expect("has diff"),
            &SHARD_ID_0,
            &VERSIONED_DATA_0,
        )
    });
    assert_none!(client_2.receiver.next().now_or_never());
}
1786
1787 async fn new_tcp_listener() -> (SocketAddr, TcpListenerStream) {
1788 let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
1789 let tcp_listener = TcpListener::bind(addr).await.expect("tcp listener");
1790
1791 (
1792 tcp_listener.local_addr().expect("bound to local address"),
1793 TcpListenerStream::new(tcp_listener),
1794 )
1795 }
1796
1797 #[allow(clippy::unused_async)]
1798 async fn spawn_server(tcp_listener_stream: TcpListenerStream) -> Arc<PubSubState> {
1799 let server = PersistGrpcPubSubServer::new(&test_persist_config(), &MetricsRegistry::new());
1800 let server_state = Arc::clone(&server.state);
1801
1802 let _server_task = mz_ore::task::spawn(|| "server".to_string(), async move {
1803 server.serve_with_stream(tcp_listener_stream).await
1804 });
1805 server_state
1806 }
1807
1808 async fn poll_until_true<F>(timeout: Duration, f: F)
1809 where
1810 F: Fn() -> bool,
1811 {
1812 let now = Instant::now();
1813 loop {
1814 if f() {
1815 return;
1816 }
1817
1818 if now.elapsed() > timeout {
1819 panic!("timed out");
1820 }
1821
1822 tokio::time::sleep(Duration::from_millis(1)).await;
1823 }
1824 }
1825
1826 fn assert_push(message: ProtoPubSubMessage, shard: &ShardId, data: &VersionedData) {
1827 let message = message.message.expect("proto contains message");
1828 match message {
1829 Message::PushDiff(x) => {
1830 assert_eq!(x.shard_id, shard.into_proto());
1831 assert_eq!(x.seqno, data.seqno.into_proto());
1832 assert_eq!(x.diff, data.data);
1833 }
1834 Message::Subscribe(_) | Message::Unsubscribe(_) => panic!("unexpected message type"),
1835 };
1836 }
1837
1838 fn test_persist_config() -> PersistConfig {
1839 let cfg = PersistConfig::new_for_tests();
1840
1841 let mut updates = ConfigUpdates::default();
1842 updates.add(&PUBSUB_CLIENT_ENABLED, true);
1843 updates.add(&PUBSUB_RECONNECT_BACKOFF, Duration::ZERO);
1844 cfg.apply_from(&updates);
1845
1846 cfg
1847 }
1848}