1use std::any::Any;
13use std::collections::BTreeMap;
14use std::collections::btree_map::Entry;
15use std::fmt::Debug;
16use std::future::Future;
17use std::sync::{Arc, RwLock, TryLockError, Weak};
18use std::time::{Duration, Instant};
19
20use differential_dataflow::difference::Semigroup;
21use differential_dataflow::lattice::Lattice;
22use mz_ore::instrument;
23use mz_ore::metrics::MetricsRegistry;
24use mz_ore::task::{AbortOnDropHandle, JoinHandle};
25use mz_ore::url::SensitiveUrl;
26use mz_persist::cfg::{BlobConfig, ConsensusConfig};
27use mz_persist::location::{
28 BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
29 VersionedData,
30};
31use mz_persist_types::{Codec, Codec64};
32use timely::progress::Timestamp;
33use tokio::sync::{Mutex, OnceCell};
34use tracing::debug;
35
36use crate::async_runtime::IsolatedRuntime;
37use crate::error::{CodecConcreteType, CodecMismatch};
38use crate::internal::cache::BlobMemCache;
39use crate::internal::machine::retry_external;
40use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
41use crate::internal::state::TypedState;
42use crate::internal::watch::StateWatchNotifier;
43use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
44use crate::schema::SchemaCacheMaps;
45use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};
46
/// Opens [`PersistClient`]s while sharing the underlying [`Blob`] and
/// [`Consensus`] connections among them.
///
/// Blob and consensus handles are cached by URI (see `open_blob` /
/// `open_consensus`), so opening many clients against the same
/// [`PersistLocation`] reuses a single connection per URI. All clients opened
/// from one cache also share the same metrics, isolated runtime, state cache,
/// and pubsub sender.
#[derive(Debug)]
pub struct PersistClientCache {
    /// Configuration handed (cloned) to every client opened from this cache.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    /// Opened blobs by URI, each paired with its RTT-latency probe task. This
    /// is a tokio mutex because `open_blob` holds it across awaits.
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    /// Opened consensus handles by URI, each paired with its RTT-latency probe
    /// task. Held across awaits in `open_consensus`.
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    /// Per-shard state shared by the clients opened from this cache.
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    /// Handle to the task created in `new` by
    /// `subscribe_state_cache_to_pubsub`. It is never read; it is only kept
    /// alongside the cache.
    _pubsub_receiver_task: JoinHandle<()>,
}
67
/// Owns the handle of a background RTT-latency probe task.
///
/// The wrapped handle is never read, hence `dead_code` is allowed. It is
/// stored only so the probe is tied to the cache entry that owns this value.
/// NOTE(review): by its name, `AbortOnDropHandle` aborts the task when dropped
/// — confirm in `mz_ore::task`.
#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);
70
71impl PersistClientCache {
72 pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
74 where
75 F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
76 {
77 let metrics = Arc::new(Metrics::new(&cfg, registry));
78 let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));
79
80 let state_cache = Arc::new(StateCache::new(
81 &cfg,
82 Arc::clone(&metrics),
83 Arc::clone(&pubsub_client.sender),
84 ));
85 let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
86 Arc::clone(&state_cache),
87 pubsub_client.receiver,
88 );
89 let isolated_runtime =
90 IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));
91
92 PersistClientCache {
93 cfg,
94 metrics,
95 blob_by_uri: Mutex::new(BTreeMap::new()),
96 consensus_by_uri: Mutex::new(BTreeMap::new()),
97 isolated_runtime: Arc::new(isolated_runtime),
98 state_cache,
99 pubsub_sender: pubsub_client.sender,
100 _pubsub_receiver_task,
101 }
102 }
103
104 pub fn new_no_metrics() -> Self {
107 Self::new(
108 PersistConfig::new_for_tests(),
109 &MetricsRegistry::new(),
110 |_, _| PubSubClientConnection::noop(),
111 )
112 }
113
114 pub fn cfg(&self) -> &PersistConfig {
116 &self.cfg
117 }
118
119 pub fn metrics(&self) -> &Arc<Metrics> {
121 &self.metrics
122 }
123
124 pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
126 self.metrics.shards.shard(shard_id, name)
127 }
128
129 pub fn clear_state_cache(&mut self) {
133 self.state_cache = Arc::new(StateCache::new(
134 &self.cfg,
135 Arc::clone(&self.metrics),
136 Arc::clone(&self.pubsub_sender),
137 ))
138 }
139
140 #[instrument(level = "debug")]
145 pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
146 let blob = self.open_blob(location.blob_uri).await?;
147 let consensus = self.open_consensus(location.consensus_uri).await?;
148 PersistClient::new(
149 self.cfg.clone(),
150 blob,
151 consensus,
152 Arc::clone(&self.metrics),
153 Arc::clone(&self.isolated_runtime),
154 Arc::clone(&self.state_cache),
155 Arc::clone(&self.pubsub_sender),
156 )
157 }
158
159 const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);
161
162 async fn open_consensus(
163 &self,
164 consensus_uri: SensitiveUrl,
165 ) -> Result<Arc<dyn Consensus>, ExternalError> {
166 let mut consensus_by_uri = self.consensus_by_uri.lock().await;
167 let consensus = match consensus_by_uri.entry(consensus_uri) {
168 Entry::Occupied(x) => Arc::clone(&x.get().1),
169 Entry::Vacant(x) => {
170 let consensus = ConsensusConfig::try_from(
173 x.key(),
174 Box::new(self.cfg.clone()),
175 self.metrics.postgres_consensus.clone(),
176 Arc::clone(&self.cfg().configs),
177 )?;
178 let consensus =
179 retry_external(&self.metrics.retries.external.consensus_open, || {
180 consensus.clone().open()
181 })
182 .await;
183 let consensus =
184 Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
185 let consensus = Arc::new(Tasked(consensus));
186 let task = consensus_rtt_latency_task(
187 Arc::clone(&consensus),
188 Arc::clone(&self.metrics),
189 Self::PROMETHEUS_SCRAPE_INTERVAL,
190 )
191 .await;
192 Arc::clone(
193 &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
194 .1,
195 )
196 }
197 };
198 Ok(consensus)
199 }
200
201 async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
202 let mut blob_by_uri = self.blob_by_uri.lock().await;
203 let blob = match blob_by_uri.entry(blob_uri) {
204 Entry::Occupied(x) => Arc::clone(&x.get().1),
205 Entry::Vacant(x) => {
206 let blob = BlobConfig::try_from(
209 x.key(),
210 Box::new(self.cfg.clone()),
211 self.metrics.s3_blob.clone(),
212 Arc::clone(&self.cfg.configs),
213 )
214 .await?;
215 let blob = retry_external(&self.metrics.retries.external.blob_open, || {
216 blob.clone().open()
217 })
218 .await;
219 let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
220 let blob = Arc::new(Tasked(blob));
221 let task = blob_rtt_latency_task(
222 Arc::clone(&blob),
223 Arc::clone(&self.metrics),
224 Self::PROMETHEUS_SCRAPE_INTERVAL,
225 )
226 .await;
227 let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
230 Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
231 }
232 };
233 Ok(blob)
234 }
235}
236
237#[allow(clippy::unused_async)]
251async fn blob_rtt_latency_task(
252 blob: Arc<Tasked<MetricsBlob>>,
253 metrics: Arc<Metrics>,
254 measurement_interval: Duration,
255) -> JoinHandle<()> {
256 mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
257 let mut next_measurement = tokio::time::Instant::now();
260 loop {
261 tokio::time::sleep_until(next_measurement).await;
262 let start = Instant::now();
263 match blob.get(BLOB_GET_LIVENESS_KEY).await {
264 Ok(_) => {
265 metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
266 }
267 Err(_) => {
268 }
272 }
273 next_measurement = tokio::time::Instant::now() + measurement_interval;
274 }
275 })
276}
277
278#[allow(clippy::unused_async)]
292async fn consensus_rtt_latency_task(
293 consensus: Arc<Tasked<MetricsConsensus>>,
294 metrics: Arc<Metrics>,
295 measurement_interval: Duration,
296) -> JoinHandle<()> {
297 mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
298 let mut next_measurement = tokio::time::Instant::now();
301 loop {
302 tokio::time::sleep_until(next_measurement).await;
303 let start = Instant::now();
304 match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
305 Ok(_) => {
306 metrics
307 .consensus
308 .rtt_latency
309 .set(start.elapsed().as_secs_f64());
310 }
311 Err(_) => {
312 }
316 }
317 next_measurement = tokio::time::Instant::now() + measurement_interval;
318 }
319 })
320}
321
322pub(crate) trait DynState: Debug + Send + Sync {
323 fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
324 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
325 fn push_diff(&self, diff: VersionedData);
326}
327
328impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
329where
330 K: Codec,
331 V: Codec,
332 T: Timestamp + Lattice + Codec64 + Sync,
333 D: Codec64,
334{
335 fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
336 (
337 K::codec_name(),
338 V::codec_name(),
339 T::codec_name(),
340 D::codec_name(),
341 Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
342 )
343 }
344
345 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
346 self
347 }
348
349 fn push_diff(&self, diff: VersionedData) {
350 self.write_lock(&self.metrics.locks.applier_write, |state| {
351 let seqno_before = state.seqno;
352 state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
353 let seqno_after = state.seqno;
354 assert!(seqno_after >= seqno_before);
355
356 if seqno_before != seqno_after {
357 debug!(
358 "applied pushed diff {}. seqno {} -> {}.",
359 state.shard_id, seqno_before, state.seqno
360 );
361 self.shard_metrics.pubsub_push_diff_applied.inc();
362 } else {
363 debug!(
364 "failed to apply pushed diff {}. seqno {} vs diff {}",
365 state.shard_id, seqno_before, diff.seqno
366 );
367 if diff.seqno <= seqno_before {
368 self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
369 } else {
370 self.shard_metrics
371 .pubsub_push_diff_not_applied_out_of_order
372 .inc();
373 }
374 }
375 })
376 }
377}
378
/// A cache of per-shard [`LockingTypedState`]s, keyed by [`ShardId`].
///
/// The map stores only [`Weak`] references, so the cache never keeps a shard's
/// state alive by itself. The state lives as long as some caller of `get`
/// still holds the returned `Arc`.
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    /// One tokio `OnceCell` per shard. The outer `std::sync::Mutex` guards
    /// only map access and is released before any await in `get`. The cell
    /// serializes the (async) initialization of a shard's state.
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}
396
/// Result of the synchronous map lookup in [`StateCache::get`].
#[derive(Debug)]
enum StateCacheInit {
    /// The shard's state is cached and still alive.
    Init(Arc<dyn DynState>),
    /// The shard's state must be (re)initialized through this cell.
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}
402
impl StateCache {
    /// Returns a new, empty cache.
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    /// Returns an empty cache with test config, throwaway metrics, and a no-op
    /// pubsub sender.
    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

    /// Returns the cached state for `shard_id`, initializing it with
    /// `init_fn` if it is not cached or has been dropped.
    ///
    /// Concurrent callers for the same shard share one initialization through
    /// the shard's `OnceCell`. `init_fn` may be called more than once across
    /// retries (e.g. after a failed init, or after the race described below).
    ///
    /// # Errors
    ///
    /// Returns the error from `init_fn`. Also returns a [`CodecMismatch`] if
    /// the cached state was created with different `K`, `V`, `T`, `D`.
    ///
    /// NOTE(review): in the code visible here, entries are never removed from
    /// `states`, so a dead `Weak` remains per shard until the shard is reused
    /// — confirm this is acceptable.
    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Semigroup + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            // Synchronous lookup. The std mutex guard is dropped at the end of
            // this block, before any await.
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // Previously initialized, but every strong ref has
                            // since been dropped. Install a fresh cell so the
                            // state can be initialized again.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    // Set only if *this* call ran the initializer. It holds
                    // the only strong ref, since the cell stores a `Weak`.
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    // We built the state ourselves. Return the concretely
                    // typed strong ref directly; no downcast is needed.
                    if let Some(x) = did_init {
                        return Ok(x);
                    }
                    // Another caller initialized the cell. Its strong ref may
                    // already have been dropped between the initialization
                    // and this upgrade; if so, retry from the top.
                    let Some(state) = state.upgrade() else {
                        continue;
                    };
                    state
                }
            };

            // Recover the concrete type. A failure means the shard's state was
            // created with different K/V/T/D than requested.
            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

    /// Returns a weak ref to `shard_id`'s state if its cell has been
    /// initialized. The ref may already be dead.
    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    /// Test helper: the live state for `shard_id`, if any.
    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    /// Test helper: number of shards whose cell has been initialized, whether
    /// or not the state is still alive.
    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    /// Test helper: number of shards whose state is still alive.
    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}
563
/// A shard's [`TypedState`] behind a `std::sync::RwLock`, together with the
/// per-shard handles needed to work with it.
///
/// Access goes through `read_lock` / `write_lock`. `write_lock` notifies
/// `notifier` whenever the seqno advances.
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    /// A type-erased `SchemaCacheMaps<K, V>`; see `schema_cache()`.
    schema_cache: Arc<dyn Any + Send + Sync>,
    /// Never read; kept for the lifetime of this state. NOTE(review):
    /// presumably dropping it ends the shard's pubsub subscription — confirm.
    _subscription_token: Arc<ShardSubscriptionToken>,
}
579
580impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
581 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
582 let LockingTypedState {
583 shard_id,
584 state,
585 notifier,
586 cfg: _cfg,
587 metrics: _metrics,
588 shard_metrics: _shard_metrics,
589 schema_cache: _schema_cache,
590 _subscription_token,
591 } = self;
592 f.debug_struct("LockingTypedState")
593 .field("shard_id", shard_id)
594 .field("state", state)
595 .field("notifier", notifier)
596 .finish()
597 }
598}
599
600impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
601 fn new(
602 shard_id: ShardId,
603 initial_state: TypedState<K, V, T, D>,
604 metrics: Arc<Metrics>,
605 cfg: Arc<PersistConfig>,
606 subscription_token: Arc<ShardSubscriptionToken>,
607 diagnostics: &Diagnostics,
608 ) -> Self {
609 Self {
610 shard_id,
611 notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
612 state: RwLock::new(initial_state),
613 cfg: Arc::clone(&cfg),
614 shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
615 schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
616 metrics,
617 _subscription_token: subscription_token,
618 }
619 }
620
621 pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
622 Arc::clone(&self.schema_cache)
623 .downcast::<SchemaCacheMaps<K, V>>()
624 .expect("K and V match")
625 }
626}
627
628impl<K, V, T, D> LockingTypedState<K, V, T, D> {
629 pub(crate) fn shard_id(&self) -> &ShardId {
630 &self.shard_id
631 }
632
633 pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
634 &self,
635 metrics: &LockMetrics,
636 mut f: F,
637 ) -> R {
638 metrics.acquire_count.inc();
639 let state = match self.state.try_read() {
640 Ok(x) => x,
641 Err(TryLockError::WouldBlock) => {
642 metrics.blocking_acquire_count.inc();
643 let start = Instant::now();
644 let state = self.state.read().expect("lock poisoned");
645 metrics
646 .blocking_seconds
647 .inc_by(start.elapsed().as_secs_f64());
648 state
649 }
650 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
651 };
652 f(&state)
653 }
654
655 pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
656 &self,
657 metrics: &LockMetrics,
658 f: F,
659 ) -> R {
660 metrics.acquire_count.inc();
661 let mut state = match self.state.try_write() {
662 Ok(x) => x,
663 Err(TryLockError::WouldBlock) => {
664 metrics.blocking_acquire_count.inc();
665 let start = Instant::now();
666 let state = self.state.write().expect("lock poisoned");
667 metrics
668 .blocking_seconds
669 .inc_by(start.elapsed().as_secs_f64());
670 state
671 }
672 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
673 };
674 let seqno_before = state.seqno;
675 let ret = f(&mut state);
676 let seqno_after = state.seqno;
677 debug_assert!(seqno_after >= seqno_before);
678 if seqno_after > seqno_before {
679 self.notifier.notify(seqno_after);
680 }
681 drop(state);
684 ret
685 }
686
687 pub(crate) fn notifier(&self) -> &StateWatchNotifier {
688 &self.notifier
689 }
690}
691
#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};

    use super::*;

    /// Each distinct blob/consensus URI gets exactly one cache entry. URIs that
    /// differ only in query, path, user, or port still count as distinct.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location on an empty cache saves the results of opening
        // both the blob and the consensus.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // A new blob with a known consensus adds only a blob entry.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // A known blob with a new consensus adds only a consensus entry.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query strings and paths make URIs distinct.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // So do user info and ports.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }

    /// Exercises `StateCache::get`: failed inits, reuse of live state, codec
    /// mismatches, multiple shards, and re-initialization after all strong
    /// refs are dropped.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // A panicking init_fn leaves the shard uninitialized.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // So does an init_fn that returns an error.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // A successful init_fn runs and initializes the shard.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert!(did_work.load(Ordering::SeqCst));
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // While that state is alive, a second get reuses it without calling
        // init_fn.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert!(!did_work.load(Ordering::SeqCst));
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Requesting different codecs for a live shard is an error, and
        // init_fn is not called.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert!(!did_work.load(Ordering::SeqCst));
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // A different shard is cached independently.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // Dropping one of two s1 handles keeps the state alive; dropping the
        // last one frees it (the cell stays initialized with a dead Weak).
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // After that, s1 can be initialized again.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    /// Many concurrent gets for the same shard must all complete.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }
}