1use std::any::Any;
13use std::collections::BTreeMap;
14use std::collections::btree_map::Entry;
15use std::fmt::Debug;
16use std::future::Future;
17use std::sync::{Arc, RwLock, TryLockError, Weak};
18use std::time::{Duration, Instant};
19
20use differential_dataflow::difference::Semigroup;
21use differential_dataflow::lattice::Lattice;
22use mz_ore::instrument;
23use mz_ore::metrics::MetricsRegistry;
24use mz_ore::task::{AbortOnDropHandle, JoinHandle};
25use mz_ore::url::SensitiveUrl;
26use mz_persist::cfg::{BlobConfig, ConsensusConfig};
27use mz_persist::location::{
28 BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
29 VersionedData,
30};
31use mz_persist_types::{Codec, Codec64};
32use timely::progress::Timestamp;
33use tokio::sync::{Mutex, OnceCell};
34use tracing::debug;
35
36use crate::async_runtime::IsolatedRuntime;
37use crate::error::{CodecConcreteType, CodecMismatch};
38use crate::internal::cache::BlobMemCache;
39use crate::internal::machine::retry_external;
40use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
41use crate::internal::state::TypedState;
42use crate::internal::watch::StateWatchNotifier;
43use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
44use crate::schema::SchemaCacheMaps;
45use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};
46
47#[derive(Debug)]
56pub struct PersistClientCache {
57 pub cfg: PersistConfig,
59 pub(crate) metrics: Arc<Metrics>,
60 blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
61 consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
62 isolated_runtime: Arc<IsolatedRuntime>,
63 pub(crate) state_cache: Arc<StateCache>,
64 pubsub_sender: Arc<dyn PubSubSender>,
65 _pubsub_receiver_task: JoinHandle<()>,
66}
67
68#[derive(Debug)]
69struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);
70
71impl PersistClientCache {
72 pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
74 where
75 F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
76 {
77 let metrics = Arc::new(Metrics::new(&cfg, registry));
78 let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));
79
80 let state_cache = Arc::new(StateCache::new(
81 &cfg,
82 Arc::clone(&metrics),
83 Arc::clone(&pubsub_client.sender),
84 ));
85 let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
86 Arc::clone(&state_cache),
87 pubsub_client.receiver,
88 );
89 let isolated_runtime =
90 IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));
91
92 PersistClientCache {
93 cfg,
94 metrics,
95 blob_by_uri: Mutex::new(BTreeMap::new()),
96 consensus_by_uri: Mutex::new(BTreeMap::new()),
97 isolated_runtime: Arc::new(isolated_runtime),
98 state_cache,
99 pubsub_sender: pubsub_client.sender,
100 _pubsub_receiver_task,
101 }
102 }
103
104 pub fn new_no_metrics() -> Self {
107 Self::new(
108 PersistConfig::new_for_tests(),
109 &MetricsRegistry::new(),
110 |_, _| PubSubClientConnection::noop(),
111 )
112 }
113
114 pub fn cfg(&self) -> &PersistConfig {
116 &self.cfg
117 }
118
119 pub fn metrics(&self) -> &Arc<Metrics> {
121 &self.metrics
122 }
123
124 pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
126 self.metrics.shards.shard(shard_id, name)
127 }
128
129 pub fn clear_state_cache(&mut self) {
133 self.state_cache = Arc::new(StateCache::new(
134 &self.cfg,
135 Arc::clone(&self.metrics),
136 Arc::clone(&self.pubsub_sender),
137 ))
138 }
139
140 #[instrument(level = "debug")]
145 pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
146 let blob = self.open_blob(location.blob_uri).await?;
147 let consensus = self.open_consensus(location.consensus_uri).await?;
148 PersistClient::new(
149 self.cfg.clone(),
150 blob,
151 consensus,
152 Arc::clone(&self.metrics),
153 Arc::clone(&self.isolated_runtime),
154 Arc::clone(&self.state_cache),
155 Arc::clone(&self.pubsub_sender),
156 )
157 }
158
159 const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);
161
162 async fn open_consensus(
163 &self,
164 consensus_uri: SensitiveUrl,
165 ) -> Result<Arc<dyn Consensus>, ExternalError> {
166 let mut consensus_by_uri = self.consensus_by_uri.lock().await;
167 let consensus = match consensus_by_uri.entry(consensus_uri) {
168 Entry::Occupied(x) => Arc::clone(&x.get().1),
169 Entry::Vacant(x) => {
170 let consensus = ConsensusConfig::try_from(
173 x.key(),
174 Box::new(self.cfg.clone()),
175 self.metrics.postgres_consensus.clone(),
176 )?;
177 let consensus =
178 retry_external(&self.metrics.retries.external.consensus_open, || {
179 consensus.clone().open()
180 })
181 .await;
182 let consensus =
183 Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
184 let consensus = Arc::new(Tasked(consensus));
185 let task = consensus_rtt_latency_task(
186 Arc::clone(&consensus),
187 Arc::clone(&self.metrics),
188 Self::PROMETHEUS_SCRAPE_INTERVAL,
189 )
190 .await;
191 Arc::clone(
192 &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
193 .1,
194 )
195 }
196 };
197 Ok(consensus)
198 }
199
200 async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
201 let mut blob_by_uri = self.blob_by_uri.lock().await;
202 let blob = match blob_by_uri.entry(blob_uri) {
203 Entry::Occupied(x) => Arc::clone(&x.get().1),
204 Entry::Vacant(x) => {
205 let blob = BlobConfig::try_from(
208 x.key(),
209 Box::new(self.cfg.clone()),
210 self.metrics.s3_blob.clone(),
211 Arc::clone(&self.cfg.configs),
212 )
213 .await?;
214 let blob = retry_external(&self.metrics.retries.external.blob_open, || {
215 blob.clone().open()
216 })
217 .await;
218 let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
219 let blob = Arc::new(Tasked(blob));
220 let task = blob_rtt_latency_task(
221 Arc::clone(&blob),
222 Arc::clone(&self.metrics),
223 Self::PROMETHEUS_SCRAPE_INTERVAL,
224 )
225 .await;
226 let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
229 Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
230 }
231 };
232 Ok(blob)
233 }
234}
235
236#[allow(clippy::unused_async)]
250async fn blob_rtt_latency_task(
251 blob: Arc<Tasked<MetricsBlob>>,
252 metrics: Arc<Metrics>,
253 measurement_interval: Duration,
254) -> JoinHandle<()> {
255 mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
256 let mut next_measurement = tokio::time::Instant::now();
259 loop {
260 tokio::time::sleep_until(next_measurement).await;
261 let start = Instant::now();
262 match blob.get(BLOB_GET_LIVENESS_KEY).await {
263 Ok(_) => {
264 metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
265 }
266 Err(_) => {
267 }
271 }
272 next_measurement = tokio::time::Instant::now() + measurement_interval;
273 }
274 })
275}
276
277#[allow(clippy::unused_async)]
291async fn consensus_rtt_latency_task(
292 consensus: Arc<Tasked<MetricsConsensus>>,
293 metrics: Arc<Metrics>,
294 measurement_interval: Duration,
295) -> JoinHandle<()> {
296 mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
297 let mut next_measurement = tokio::time::Instant::now();
300 loop {
301 tokio::time::sleep_until(next_measurement).await;
302 let start = Instant::now();
303 match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
304 Ok(_) => {
305 metrics
306 .consensus
307 .rtt_latency
308 .set(start.elapsed().as_secs_f64());
309 }
310 Err(_) => {
311 }
315 }
316 next_measurement = tokio::time::Instant::now() + measurement_interval;
317 }
318 })
319}
320
321pub(crate) trait DynState: Debug + Send + Sync {
322 fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
323 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
324 fn push_diff(&self, diff: VersionedData);
325}
326
327impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
328where
329 K: Codec,
330 V: Codec,
331 T: Timestamp + Lattice + Codec64 + Sync,
332 D: Codec64,
333{
334 fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
335 (
336 K::codec_name(),
337 V::codec_name(),
338 T::codec_name(),
339 D::codec_name(),
340 Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
341 )
342 }
343
344 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
345 self
346 }
347
348 fn push_diff(&self, diff: VersionedData) {
349 self.write_lock(&self.metrics.locks.applier_write, |state| {
350 let seqno_before = state.seqno;
351 state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
352 let seqno_after = state.seqno;
353 assert!(seqno_after >= seqno_before);
354
355 if seqno_before != seqno_after {
356 debug!(
357 "applied pushed diff {}. seqno {} -> {}.",
358 state.shard_id, seqno_before, state.seqno
359 );
360 self.shard_metrics.pubsub_push_diff_applied.inc();
361 } else {
362 debug!(
363 "failed to apply pushed diff {}. seqno {} vs diff {}",
364 state.shard_id, seqno_before, diff.seqno
365 );
366 if diff.seqno <= seqno_before {
367 self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
368 } else {
369 self.shard_metrics
370 .pubsub_push_diff_not_applied_out_of_order
371 .inc();
372 }
373 }
374 })
375 }
376}
377
378#[derive(Debug)]
389pub struct StateCache {
390 cfg: Arc<PersistConfig>,
391 pub(crate) metrics: Arc<Metrics>,
392 states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
393 pubsub_sender: Arc<dyn PubSubSender>,
394}
395
396#[derive(Debug)]
397enum StateCacheInit {
398 Init(Arc<dyn DynState>),
399 NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
400}
401
402impl StateCache {
403 pub fn new(
405 cfg: &PersistConfig,
406 metrics: Arc<Metrics>,
407 pubsub_sender: Arc<dyn PubSubSender>,
408 ) -> Self {
409 StateCache {
410 cfg: Arc::new(cfg.clone()),
411 metrics,
412 states: Default::default(),
413 pubsub_sender,
414 }
415 }
416
417 #[cfg(test)]
418 pub(crate) fn new_no_metrics() -> Self {
419 Self::new(
420 &PersistConfig::new_for_tests(),
421 Arc::new(Metrics::new(
422 &PersistConfig::new_for_tests(),
423 &MetricsRegistry::new(),
424 )),
425 Arc::new(crate::rpc::NoopPubSubSender),
426 )
427 }
428
429 pub(crate) async fn get<K, V, T, D, F, InitFn>(
430 &self,
431 shard_id: ShardId,
432 mut init_fn: InitFn,
433 diagnostics: &Diagnostics,
434 ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
435 where
436 K: Debug + Codec,
437 V: Debug + Codec,
438 T: Timestamp + Lattice + Codec64 + Sync,
439 D: Semigroup + Codec64,
440 F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
441 InitFn: FnMut() -> F,
442 {
443 loop {
444 let init = {
445 let mut states = self.states.lock().expect("lock poisoned");
446 let state = states.entry(shard_id).or_default();
447 match state.get() {
448 Some(once_val) => match once_val.upgrade() {
449 Some(x) => StateCacheInit::Init(x),
450 None => {
451 *state = Arc::new(OnceCell::new());
455 StateCacheInit::NeedInit(Arc::clone(state))
456 }
457 },
458 None => StateCacheInit::NeedInit(Arc::clone(state)),
459 }
460 };
461
462 let state = match init {
463 StateCacheInit::Init(x) => x,
464 StateCacheInit::NeedInit(init_once) => {
465 let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
466 let state = init_once
467 .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
468 let init_res = init_fn().await;
469 let state = Arc::new(LockingTypedState::new(
470 shard_id,
471 init_res?,
472 Arc::clone(&self.metrics),
473 Arc::clone(&self.cfg),
474 Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
475 diagnostics,
476 ));
477 let ret = Arc::downgrade(&state);
478 did_init = Some(state);
479 let ret: Weak<dyn DynState> = ret;
480 Ok(ret)
481 })
482 .await?;
483 if let Some(x) = did_init {
484 return Ok(x);
488 }
489 let Some(state) = state.upgrade() else {
490 continue;
497 };
498 state
499 }
500 };
501
502 match Arc::clone(&state)
503 .as_any()
504 .downcast::<LockingTypedState<K, V, T, D>>()
505 {
506 Ok(x) => return Ok(x),
507 Err(_) => {
508 return Err(Box::new(CodecMismatch {
509 requested: (
510 K::codec_name(),
511 V::codec_name(),
512 T::codec_name(),
513 D::codec_name(),
514 Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
515 ),
516 actual: state.codecs(),
517 }));
518 }
519 }
520 }
521 }
522
523 pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
524 self.states
525 .lock()
526 .expect("lock")
527 .get(shard_id)
528 .and_then(|x| x.get())
529 .map(Weak::clone)
530 }
531
532 #[cfg(test)]
533 fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
534 self.states
535 .lock()
536 .expect("lock")
537 .get(shard_id)
538 .and_then(|x| x.get())
539 .and_then(|x| x.upgrade())
540 }
541
542 #[cfg(test)]
543 fn initialized_count(&self) -> usize {
544 self.states
545 .lock()
546 .expect("lock")
547 .values()
548 .filter(|x| x.initialized())
549 .count()
550 }
551
552 #[cfg(test)]
553 fn strong_count(&self) -> usize {
554 self.states
555 .lock()
556 .expect("lock")
557 .values()
558 .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
559 .count()
560 }
561}
562
563pub(crate) struct LockingTypedState<K, V, T, D> {
567 shard_id: ShardId,
568 state: RwLock<TypedState<K, V, T, D>>,
569 notifier: StateWatchNotifier,
570 cfg: Arc<PersistConfig>,
571 metrics: Arc<Metrics>,
572 shard_metrics: Arc<ShardMetrics>,
573 schema_cache: Arc<dyn Any + Send + Sync>,
576 _subscription_token: Arc<ShardSubscriptionToken>,
577}
578
579impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 let LockingTypedState {
582 shard_id,
583 state,
584 notifier,
585 cfg: _cfg,
586 metrics: _metrics,
587 shard_metrics: _shard_metrics,
588 schema_cache: _schema_cache,
589 _subscription_token,
590 } = self;
591 f.debug_struct("LockingTypedState")
592 .field("shard_id", shard_id)
593 .field("state", state)
594 .field("notifier", notifier)
595 .finish()
596 }
597}
598
599impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
600 fn new(
601 shard_id: ShardId,
602 initial_state: TypedState<K, V, T, D>,
603 metrics: Arc<Metrics>,
604 cfg: Arc<PersistConfig>,
605 subscription_token: Arc<ShardSubscriptionToken>,
606 diagnostics: &Diagnostics,
607 ) -> Self {
608 Self {
609 shard_id,
610 notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
611 state: RwLock::new(initial_state),
612 cfg: Arc::clone(&cfg),
613 shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
614 schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
615 metrics,
616 _subscription_token: subscription_token,
617 }
618 }
619
620 pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
621 Arc::clone(&self.schema_cache)
622 .downcast::<SchemaCacheMaps<K, V>>()
623 .expect("K and V match")
624 }
625}
626
627impl<K, V, T, D> LockingTypedState<K, V, T, D> {
628 pub(crate) fn shard_id(&self) -> &ShardId {
629 &self.shard_id
630 }
631
632 pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
633 &self,
634 metrics: &LockMetrics,
635 mut f: F,
636 ) -> R {
637 metrics.acquire_count.inc();
638 let state = match self.state.try_read() {
639 Ok(x) => x,
640 Err(TryLockError::WouldBlock) => {
641 metrics.blocking_acquire_count.inc();
642 let start = Instant::now();
643 let state = self.state.read().expect("lock poisoned");
644 metrics
645 .blocking_seconds
646 .inc_by(start.elapsed().as_secs_f64());
647 state
648 }
649 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
650 };
651 f(&state)
652 }
653
654 pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
655 &self,
656 metrics: &LockMetrics,
657 f: F,
658 ) -> R {
659 metrics.acquire_count.inc();
660 let mut state = match self.state.try_write() {
661 Ok(x) => x,
662 Err(TryLockError::WouldBlock) => {
663 metrics.blocking_acquire_count.inc();
664 let start = Instant::now();
665 let state = self.state.write().expect("lock poisoned");
666 metrics
667 .blocking_seconds
668 .inc_by(start.elapsed().as_secs_f64());
669 state
670 }
671 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
672 };
673 let seqno_before = state.seqno;
674 let ret = f(&mut state);
675 let seqno_after = state.seqno;
676 debug_assert!(seqno_after >= seqno_before);
677 if seqno_after > seqno_before {
678 self.notifier.notify(seqno_after);
679 }
680 drop(state);
683 ret
684 }
685
686 pub(crate) fn notifier(&self) -> &StateWatchNotifier {
687 &self.notifier
688 }
689}
690
#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};

    use super::*;

    /// Blob and consensus handles are cached independently, keyed by the
    /// full URI.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn client_cache() {
        fn location(blob: &str, consensus: &str) -> PersistLocation {
            PersistLocation {
                blob_uri: SensitiveUrl::from_str(blob).expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str(consensus).expect("invalid URL"),
            }
        }

        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // (blob URI, consensus URI, expected blob entries, expected consensus
        // entries) after opening each location, in order.
        let cases = [
            // Opening on an empty cache stores both handles.
            ("mem://blob_zero", "mem://consensus_zero", 1, 1),
            // A new blob URI with an already-opened consensus URI only adds
            // a blob entry.
            ("mem://blob_one", "mem://consensus_zero", 2, 1),
            // And the other way around.
            ("mem://blob_one", "mem://consensus_one", 2, 2),
            // Query parameters and paths are part of the cache key.
            ("mem://blob_one?foo", "mem://consensus_one/bar", 3, 3),
            // So are user info and port.
            ("mem://user@blob_one", "mem://@consensus_one:123", 4, 4),
        ];
        for (blob, consensus, blob_entries, consensus_entries) in cases {
            let _ = cache
                .open(location(blob, consensus))
                .await
                .expect("failed to open location");
            assert_eq!(
                cache.blob_by_uri.lock().await.len(),
                blob_entries,
                "blob entries after opening {blob}"
            );
            assert_eq!(
                cache.consensus_by_uri.lock().await.len(),
                consensus_entries,
                "consensus entries after opening {consensus}"
            );
        }
    }

    /// Exercises initialization, reuse, codec mismatches, and re-creation of
    /// states after they have been dropped.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts out empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // A panic inside the init function leaves the shard uninitialized.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // An error returned by the init function does, too.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // A successful init runs the init function and caches the state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert!(did_work.load(Ordering::SeqCst));
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Getting the same shard again does no work and yields the same state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert!(!did_work.load(Ordering::SeqCst));
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Asking for the same shard with different types fails without
        // running the init function.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert!(!did_work.load(Ordering::SeqCst));
        assert_eq!(
            res.expect_err("types shouldn't match").to_string(),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // A second shard is tracked independently of the first.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // The state stays alive until the last strong reference is dropped;
        // after that the cell is still initialized but no longer upgradable.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // A dropped state can be initialized again.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    /// Many concurrent `get` calls for the same shard all run to completion.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        // Every future must yield a result. Each result is dropped right
        // away, so a state may be dropped and re-created while other calls
        // are still in flight.
        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }
}