1use std::any::Any;
13use std::collections::BTreeMap;
14use std::collections::btree_map::Entry;
15use std::fmt::Debug;
16use std::future::Future;
17use std::sync::{Arc, RwLock, TryLockError, Weak};
18use std::time::{Duration, Instant};
19
20use differential_dataflow::difference::Monoid;
21use differential_dataflow::lattice::Lattice;
22use mz_ore::instrument;
23use mz_ore::metrics::MetricsRegistry;
24use mz_ore::task::{AbortOnDropHandle, JoinHandle};
25use mz_ore::url::SensitiveUrl;
26use mz_persist::cfg::{BlobConfig, ConsensusConfig};
27use mz_persist::location::{
28 BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
29 VersionedData,
30};
31use mz_persist_types::{Codec, Codec64};
32use timely::progress::Timestamp;
33use tokio::sync::{Mutex, OnceCell};
34use tracing::debug;
35
36use crate::async_runtime::IsolatedRuntime;
37use crate::error::{CodecConcreteType, CodecMismatch};
38use crate::internal::cache::BlobMemCache;
39use crate::internal::machine::retry_external;
40use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
41use crate::internal::state::TypedState;
42use crate::internal::watch::StateWatchNotifier;
43use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
44use crate::schema::SchemaCacheMaps;
45use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};
46
/// A cache of the dependencies needed to build a [PersistClient], shared by
/// every client opened through it.
///
/// Opened [Blob] and [Consensus] handles are memoized by URI, each stored next
/// to the background task that measures its round-trip latency. All clients
/// opened from one cache share a single [StateCache], [IsolatedRuntime], and
/// pubsub sender.
#[derive(Debug)]
pub struct PersistClientCache {
    /// Configuration; a clone is handed to every client opened from this cache.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    /// Blob handles keyed by URI, each paired with its RTT-latency task.
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    /// Consensus handles keyed by URI, each paired with its RTT-latency task.
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    /// Handle of the task created by `subscribe_state_cache_to_pubsub` in
    /// `new`; the field is never read, only held.
    _pubsub_receiver_task: JoinHandle<()>,
}
67
/// Ties a spawned round-trip-latency task to the cache entry that owns it.
///
/// The inner handle is never read (hence `dead_code`). It is held only for its
/// drop behavior: per its name, dropping an [AbortOnDropHandle] aborts the task.
#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);
70
71impl PersistClientCache {
72 pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
74 where
75 F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
76 {
77 let metrics = Arc::new(Metrics::new(&cfg, registry));
78 let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));
79
80 let state_cache = Arc::new(StateCache::new(
81 &cfg,
82 Arc::clone(&metrics),
83 Arc::clone(&pubsub_client.sender),
84 ));
85 let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
86 Arc::clone(&state_cache),
87 pubsub_client.receiver,
88 );
89 let isolated_runtime =
90 IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));
91
92 PersistClientCache {
93 cfg,
94 metrics,
95 blob_by_uri: Mutex::new(BTreeMap::new()),
96 consensus_by_uri: Mutex::new(BTreeMap::new()),
97 isolated_runtime: Arc::new(isolated_runtime),
98 state_cache,
99 pubsub_sender: pubsub_client.sender,
100 _pubsub_receiver_task,
101 }
102 }
103
104 pub fn new_no_metrics() -> Self {
107 Self::new(
108 PersistConfig::new_for_tests(),
109 &MetricsRegistry::new(),
110 |_, _| PubSubClientConnection::noop(),
111 )
112 }
113
114 #[cfg(feature = "turmoil")]
115 pub fn new_for_turmoil() -> Self {
120 use crate::rpc::NoopPubSubSender;
121
122 let cfg = PersistConfig::new_for_tests();
123 let metrics = Arc::new(Metrics::new(&cfg, &MetricsRegistry::new()));
124
125 let pubsub_sender: Arc<dyn PubSubSender> = Arc::new(NoopPubSubSender);
126 let _pubsub_receiver_task = mz_ore::task::spawn(|| "noop", async {});
127
128 let state_cache = Arc::new(StateCache::new(
129 &cfg,
130 Arc::clone(&metrics),
131 Arc::clone(&pubsub_sender),
132 ));
133 let isolated_runtime = IsolatedRuntime::new_disabled();
134
135 PersistClientCache {
136 cfg,
137 metrics,
138 blob_by_uri: Mutex::new(BTreeMap::new()),
139 consensus_by_uri: Mutex::new(BTreeMap::new()),
140 isolated_runtime: Arc::new(isolated_runtime),
141 state_cache,
142 pubsub_sender,
143 _pubsub_receiver_task,
144 }
145 }
146
147 pub fn cfg(&self) -> &PersistConfig {
149 &self.cfg
150 }
151
152 pub fn metrics(&self) -> &Arc<Metrics> {
154 &self.metrics
155 }
156
157 pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
159 self.metrics.shards.shard(shard_id, name)
160 }
161
162 pub fn clear_state_cache(&mut self) {
166 self.state_cache = Arc::new(StateCache::new(
167 &self.cfg,
168 Arc::clone(&self.metrics),
169 Arc::clone(&self.pubsub_sender),
170 ))
171 }
172
173 #[instrument(level = "debug")]
178 pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
179 let blob = self.open_blob(location.blob_uri).await?;
180 let consensus = self.open_consensus(location.consensus_uri).await?;
181 PersistClient::new(
182 self.cfg.clone(),
183 blob,
184 consensus,
185 Arc::clone(&self.metrics),
186 Arc::clone(&self.isolated_runtime),
187 Arc::clone(&self.state_cache),
188 Arc::clone(&self.pubsub_sender),
189 )
190 }
191
    /// Interval between round-trip latency probes of each newly opened blob
    /// and consensus. It is passed to `blob_rtt_latency_task` and
    /// `consensus_rtt_latency_task`. Presumably chosen to match the Prometheus
    /// scrape interval, per its name.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);
194
195 async fn open_consensus(
196 &self,
197 consensus_uri: SensitiveUrl,
198 ) -> Result<Arc<dyn Consensus>, ExternalError> {
199 let mut consensus_by_uri = self.consensus_by_uri.lock().await;
200 let consensus = match consensus_by_uri.entry(consensus_uri) {
201 Entry::Occupied(x) => Arc::clone(&x.get().1),
202 Entry::Vacant(x) => {
203 let consensus = ConsensusConfig::try_from(
206 x.key(),
207 Box::new(self.cfg.clone()),
208 self.metrics.postgres_consensus.clone(),
209 Arc::clone(&self.cfg().configs),
210 )?;
211 let consensus =
212 retry_external(&self.metrics.retries.external.consensus_open, || {
213 consensus.clone().open()
214 })
215 .await;
216 let consensus =
217 Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
218 let consensus = Arc::new(Tasked(consensus));
219 let task = consensus_rtt_latency_task(
220 Arc::clone(&consensus),
221 Arc::clone(&self.metrics),
222 Self::PROMETHEUS_SCRAPE_INTERVAL,
223 )
224 .await;
225 Arc::clone(
226 &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
227 .1,
228 )
229 }
230 };
231 Ok(consensus)
232 }
233
234 async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
235 let mut blob_by_uri = self.blob_by_uri.lock().await;
236 let blob = match blob_by_uri.entry(blob_uri) {
237 Entry::Occupied(x) => Arc::clone(&x.get().1),
238 Entry::Vacant(x) => {
239 let blob = BlobConfig::try_from(
242 x.key(),
243 Box::new(self.cfg.clone()),
244 self.metrics.s3_blob.clone(),
245 Arc::clone(&self.cfg.configs),
246 )
247 .await?;
248 let blob = retry_external(&self.metrics.retries.external.blob_open, || {
249 blob.clone().open()
250 })
251 .await;
252 let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
253 let blob = Arc::new(Tasked(blob));
254 let task = blob_rtt_latency_task(
255 Arc::clone(&blob),
256 Arc::clone(&self.metrics),
257 Self::PROMETHEUS_SCRAPE_INTERVAL,
258 )
259 .await;
260 let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
263 Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
264 }
265 };
266 Ok(blob)
267 }
268}
269
270#[allow(clippy::unused_async)]
284async fn blob_rtt_latency_task(
285 blob: Arc<Tasked<MetricsBlob>>,
286 metrics: Arc<Metrics>,
287 measurement_interval: Duration,
288) -> JoinHandle<()> {
289 mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
290 let mut next_measurement = tokio::time::Instant::now();
293 loop {
294 tokio::time::sleep_until(next_measurement).await;
295 let start = Instant::now();
296 match blob.get(BLOB_GET_LIVENESS_KEY).await {
297 Ok(_) => {
298 metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
299 }
300 Err(_) => {
301 }
305 }
306 next_measurement = tokio::time::Instant::now() + measurement_interval;
307 }
308 })
309}
310
311#[allow(clippy::unused_async)]
325async fn consensus_rtt_latency_task(
326 consensus: Arc<Tasked<MetricsConsensus>>,
327 metrics: Arc<Metrics>,
328 measurement_interval: Duration,
329) -> JoinHandle<()> {
330 mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
331 let mut next_measurement = tokio::time::Instant::now();
334 loop {
335 tokio::time::sleep_until(next_measurement).await;
336 let start = Instant::now();
337 match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
338 Ok(_) => {
339 metrics
340 .consensus
341 .rtt_latency
342 .set(start.elapsed().as_secs_f64());
343 }
344 Err(_) => {
345 }
349 }
350 next_measurement = tokio::time::Instant::now() + measurement_interval;
351 }
352 })
353}
354
/// Type-erased view of a [LockingTypedState].
///
/// It lets [StateCache] keep the states of shards with different `K`, `V`,
/// `T`, `D` parameters in a single map.
pub(crate) trait DynState: Debug + Send + Sync {
    /// Returns the codec names of `K`, `V`, `T`, `D`, plus the concrete Rust
    /// type name. Used to build a [CodecMismatch] report.
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    /// Upcasts to [Any] so a caller can try to downcast back to the concrete
    /// [LockingTypedState].
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    /// Applies an encoded state diff (e.g. one pushed via pubsub) to the state.
    fn push_diff(&self, diff: VersionedData);
}
360
361impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
362where
363 K: Codec,
364 V: Codec,
365 T: Timestamp + Lattice + Codec64 + Sync,
366 D: Codec64,
367{
368 fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
369 (
370 K::codec_name(),
371 V::codec_name(),
372 T::codec_name(),
373 D::codec_name(),
374 Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
375 )
376 }
377
378 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
379 self
380 }
381
382 fn push_diff(&self, diff: VersionedData) {
383 self.write_lock(&self.metrics.locks.applier_write, |state| {
384 let seqno_before = state.seqno;
385 state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
386 let seqno_after = state.seqno;
387 assert!(seqno_after >= seqno_before);
388
389 if seqno_before != seqno_after {
390 debug!(
391 "applied pushed diff {}. seqno {} -> {}.",
392 state.shard_id, seqno_before, state.seqno
393 );
394 self.shard_metrics.pubsub_push_diff_applied.inc();
395 } else {
396 debug!(
397 "failed to apply pushed diff {}. seqno {} vs diff {}",
398 state.shard_id, seqno_before, diff.seqno
399 );
400 if diff.seqno <= seqno_before {
401 self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
402 } else {
403 self.shard_metrics
404 .pubsub_push_diff_not_applied_out_of_order
405 .inc();
406 }
407 }
408 })
409 }
410}
411
/// A cache of per-shard [LockingTypedState]s, shared by every client opened
/// from the same [PersistClientCache].
///
/// Each entry holds only a [Weak] reference, so the cache never keeps a
/// shard's state alive by itself. Once every strong `Arc` handed out by
/// [StateCache::get] has been dropped, the next `get` re-initializes the entry.
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    /// Per-shard init cell. The shared `OnceCell` lets concurrent `get`s for
    /// one shard share a single initialization.
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    /// Used to create each new state's shard subscription.
    pubsub_sender: Arc<dyn PubSubSender>,
}
429
/// Outcome of the synchronous, map-locked half of [StateCache::get].
#[derive(Debug)]
enum StateCacheInit {
    /// A live state is already cached for the shard.
    Init(Arc<dyn DynState>),
    /// The shard must be (re-)initialized through this shared cell. That
    /// happens outside the map lock, because initialization is async.
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}
435
436impl StateCache {
437 pub fn new(
439 cfg: &PersistConfig,
440 metrics: Arc<Metrics>,
441 pubsub_sender: Arc<dyn PubSubSender>,
442 ) -> Self {
443 StateCache {
444 cfg: Arc::new(cfg.clone()),
445 metrics,
446 states: Default::default(),
447 pubsub_sender,
448 }
449 }
450
451 #[cfg(test)]
452 pub(crate) fn new_no_metrics() -> Self {
453 Self::new(
454 &PersistConfig::new_for_tests(),
455 Arc::new(Metrics::new(
456 &PersistConfig::new_for_tests(),
457 &MetricsRegistry::new(),
458 )),
459 Arc::new(crate::rpc::NoopPubSubSender),
460 )
461 }
462
    /// Returns the cached [LockingTypedState] for `shard_id`, initializing it
    /// with `init_fn` if there is no live cached state.
    ///
    /// The map lock is held only while inspecting or updating the shard's
    /// entry, never across an `.await`. The async initialization goes through
    /// the shard's shared [OnceCell], so concurrent callers for one shard share
    /// a single successful initialization.
    ///
    /// `init_fn` can still be called again later. If every strong reference to
    /// the state is dropped, the entry must be re-initialized.
    ///
    /// # Errors
    ///
    /// Returns a [CodecMismatch] in two cases:
    /// - the initializer future fails with one;
    /// - the cached state was created with different `K`, `V`, `T`, `D` types
    ///   than requested.
    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Monoid + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            // Decide under the synchronous map lock whether a live state
            // exists. The guard is scoped to this block, so it is released
            // before any `.await` below.
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // The cell was initialized, but that state has
                            // since been dropped. A OnceCell can only be set
                            // once, so install a fresh cell to re-initialize.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    // Set only if *this* call's initializer ran. It holds the
                    // strong Arc, because the cell itself stores only a Weak.
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    // We built the state ourselves, so it has exactly the
                    // requested types and needs no downcast.
                    if let Some(x) = did_init {
                        return Ok(x);
                    }
                    // Another caller initialized the cell, but every strong
                    // reference to that state may already have been dropped.
                    // If so, retry; the next iteration sees the dead Weak and
                    // installs a fresh cell.
                    let Some(state) = state.upgrade() else {
                        continue;
                    };
                    state
                }
            };

            // The cached state is type-erased. A failed downcast means it was
            // created with different K/V/T/D types than requested here.
            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }
556
557 pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
558 self.states
559 .lock()
560 .expect("lock")
561 .get(shard_id)
562 .and_then(|x| x.get())
563 .map(Weak::clone)
564 }
565
566 #[cfg(test)]
567 fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
568 self.states
569 .lock()
570 .expect("lock")
571 .get(shard_id)
572 .and_then(|x| x.get())
573 .and_then(|x| x.upgrade())
574 }
575
576 #[cfg(test)]
577 fn initialized_count(&self) -> usize {
578 self.states
579 .lock()
580 .expect("lock")
581 .values()
582 .filter(|x| x.initialized())
583 .count()
584 }
585
586 #[cfg(test)]
587 fn strong_count(&self) -> usize {
588 self.states
589 .lock()
590 .expect("lock")
591 .values()
592 .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
593 .count()
594 }
595}
596
/// The in-memory [TypedState] of one shard, behind a [RwLock], together with
/// the bookkeeping that travels with it.
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    /// Notified by `write_lock` whenever a write advances the seqno.
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    /// Type-erased `SchemaCacheMaps<K, V>`, recovered by `schema_cache()`.
    /// Presumably erased because building it needs `K: Codec, V: Codec`
    /// bounds that this struct does not carry. Confirm.
    schema_cache: Arc<dyn Any + Send + Sync>,
    /// Never read; held so the shard's pubsub subscription lives as long as
    /// this state. Presumably dropping it unsubscribes; verify in `crate::rpc`.
    _subscription_token: Arc<ShardSubscriptionToken>,
}
612
613impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
614 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
615 let LockingTypedState {
616 shard_id,
617 state,
618 notifier,
619 cfg: _cfg,
620 metrics: _metrics,
621 shard_metrics: _shard_metrics,
622 schema_cache: _schema_cache,
623 _subscription_token,
624 } = self;
625 f.debug_struct("LockingTypedState")
626 .field("shard_id", shard_id)
627 .field("state", state)
628 .field("notifier", notifier)
629 .finish()
630 }
631}
632
633impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
634 fn new(
635 shard_id: ShardId,
636 initial_state: TypedState<K, V, T, D>,
637 metrics: Arc<Metrics>,
638 cfg: Arc<PersistConfig>,
639 subscription_token: Arc<ShardSubscriptionToken>,
640 diagnostics: &Diagnostics,
641 ) -> Self {
642 Self {
643 shard_id,
644 notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
645 state: RwLock::new(initial_state),
646 cfg: Arc::clone(&cfg),
647 shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
648 schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
649 metrics,
650 _subscription_token: subscription_token,
651 }
652 }
653
654 pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
655 Arc::clone(&self.schema_cache)
656 .downcast::<SchemaCacheMaps<K, V>>()
657 .expect("K and V match")
658 }
659}
660
661impl<K, V, T, D> LockingTypedState<K, V, T, D> {
662 pub(crate) fn shard_id(&self) -> &ShardId {
663 &self.shard_id
664 }
665
666 pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
667 &self,
668 metrics: &LockMetrics,
669 mut f: F,
670 ) -> R {
671 metrics.acquire_count.inc();
672 let state = match self.state.try_read() {
673 Ok(x) => x,
674 Err(TryLockError::WouldBlock) => {
675 metrics.blocking_acquire_count.inc();
676 let start = Instant::now();
677 let state = self.state.read().expect("lock poisoned");
678 metrics
679 .blocking_seconds
680 .inc_by(start.elapsed().as_secs_f64());
681 state
682 }
683 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
684 };
685 f(&state)
686 }
687
688 pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
689 &self,
690 metrics: &LockMetrics,
691 f: F,
692 ) -> R {
693 metrics.acquire_count.inc();
694 let mut state = match self.state.try_write() {
695 Ok(x) => x,
696 Err(TryLockError::WouldBlock) => {
697 metrics.blocking_acquire_count.inc();
698 let start = Instant::now();
699 let state = self.state.write().expect("lock poisoned");
700 metrics
701 .blocking_seconds
702 .inc_by(start.elapsed().as_secs_f64());
703 state
704 }
705 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
706 };
707 let seqno_before = state.seqno;
708 let ret = f(&mut state);
709 let seqno_after = state.seqno;
710 debug_assert!(seqno_after >= seqno_before);
711 if seqno_after > seqno_before {
712 self.notifier.notify(seqno_after);
713 }
714 drop(state);
717 ret
718 }
719
720 pub(crate) fn notifier(&self) -> &StateWatchNotifier {
721 &self.notifier
722 }
723}
724
725#[cfg(test)]
726mod tests {
727 use std::ops::Deref;
728 use std::str::FromStr;
729 use std::sync::atomic::{AtomicBool, Ordering};
730
731 use futures::stream::{FuturesUnordered, StreamExt};
732 use mz_build_info::DUMMY_BUILD_INFO;
733 use mz_ore::task::spawn;
734 use mz_ore::{assert_err, assert_none};
735
736 use super::*;
737
738 #[mz_ore::test(tokio::test)]
739 #[cfg_attr(miri, ignore)] async fn client_cache() {
741 let cache = PersistClientCache::new(
742 PersistConfig::new_for_tests(),
743 &MetricsRegistry::new(),
744 |_, _| PubSubClientConnection::noop(),
745 );
746 assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
747 assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);
748
749 let _ = cache
751 .open(PersistLocation {
752 blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
753 consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
754 })
755 .await
756 .expect("failed to open location");
757 assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
758 assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);
759
760 let _ = cache
763 .open(PersistLocation {
764 blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
765 consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
766 })
767 .await
768 .expect("failed to open location");
769 assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
770 assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);
771
772 let _ = cache
774 .open(PersistLocation {
775 blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
776 consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
777 })
778 .await
779 .expect("failed to open location");
780 assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
781 assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);
782
783 let _ = cache
785 .open(PersistLocation {
786 blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
787 consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
788 .expect("invalid URL"),
789 })
790 .await
791 .expect("failed to open location");
792 assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
793 assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);
794
795 let _ = cache
797 .open(PersistLocation {
798 blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
799 consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
800 .expect("invalid URL"),
801 })
802 .await
803 .expect("failed to open location");
804 assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
805 assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
806 }
807
808 #[mz_ore::test(tokio::test)]
809 #[cfg_attr(miri, ignore)] async fn state_cache() {
811 mz_ore::test::init_logging();
812 fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
813 where
814 K: Codec,
815 V: Codec,
816 T: Timestamp + Lattice + Codec64,
817 D: Codec64,
818 {
819 TypedState::new(
820 DUMMY_BUILD_INFO.semver_version(),
821 shard_id,
822 "host".into(),
823 0,
824 )
825 }
826 fn assert_same<K, V, T, D>(
827 state1: &LockingTypedState<K, V, T, D>,
828 state2: &LockingTypedState<K, V, T, D>,
829 ) {
830 let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
831 let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
832 assert_eq!(pointer1, pointer2);
833 }
834
835 let s1 = ShardId::new();
836 let states = Arc::new(StateCache::new_no_metrics());
837
838 assert_eq!(states.states.lock().expect("lock").len(), 0);
840
841 let s = Arc::clone(&states);
843 let res = spawn(|| "test", async move {
844 s.get::<(), (), u64, i64, _, _>(
845 s1,
846 || async { panic!("forced panic") },
847 &Diagnostics::for_tests(),
848 )
849 .await
850 })
851 .into_tokio_handle()
852 .await;
853 assert_err!(res);
854 assert_eq!(states.initialized_count(), 0);
855
856 let res = states
858 .get::<(), (), u64, i64, _, _>(
859 s1,
860 || async {
861 Err(Box::new(CodecMismatch {
862 requested: ("".into(), "".into(), "".into(), "".into(), None),
863 actual: ("".into(), "".into(), "".into(), "".into(), None),
864 }))
865 },
866 &Diagnostics::for_tests(),
867 )
868 .await;
869 assert_err!(res);
870 assert_eq!(states.initialized_count(), 0);
871
872 let did_work = Arc::new(AtomicBool::new(false));
874 let s1_state1 = states
875 .get::<(), (), u64, i64, _, _>(
876 s1,
877 || {
878 let did_work = Arc::clone(&did_work);
879 async move {
880 did_work.store(true, Ordering::SeqCst);
881 Ok(new_state(s1))
882 }
883 },
884 &Diagnostics::for_tests(),
885 )
886 .await
887 .expect("should successfully initialize");
888 assert_eq!(did_work.load(Ordering::SeqCst), true);
889 assert_eq!(states.initialized_count(), 1);
890 assert_eq!(states.strong_count(), 1);
891
892 let did_work = Arc::new(AtomicBool::new(false));
894 let s1_state2 = states
895 .get::<(), (), u64, i64, _, _>(
896 s1,
897 || {
898 let did_work = Arc::clone(&did_work);
899 async move {
900 did_work.store(true, Ordering::SeqCst);
901 did_work.store(true, Ordering::SeqCst);
902 Ok(new_state(s1))
903 }
904 },
905 &Diagnostics::for_tests(),
906 )
907 .await
908 .expect("should successfully initialize");
909 assert_eq!(did_work.load(Ordering::SeqCst), false);
910 assert_eq!(states.initialized_count(), 1);
911 assert_eq!(states.strong_count(), 1);
912 assert_same(&s1_state1, &s1_state2);
913
914 let did_work = Arc::new(AtomicBool::new(false));
916 let res = states
917 .get::<String, (), u64, i64, _, _>(
918 s1,
919 || {
920 let did_work = Arc::clone(&did_work);
921 async move {
922 did_work.store(true, Ordering::SeqCst);
923 Ok(new_state(s1))
924 }
925 },
926 &Diagnostics::for_tests(),
927 )
928 .await;
929 assert_eq!(did_work.load(Ordering::SeqCst), false);
930 assert_eq!(
931 format!("{}", res.expect_err("types shouldn't match")),
932 "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
933 );
934 assert_eq!(states.initialized_count(), 1);
935 assert_eq!(states.strong_count(), 1);
936
937 let s2 = ShardId::new();
939 let s2_state1 = states
940 .get::<String, (), u64, i64, _, _>(
941 s2,
942 || async { Ok(new_state(s2)) },
943 &Diagnostics::for_tests(),
944 )
945 .await
946 .expect("should successfully initialize");
947 assert_eq!(states.initialized_count(), 2);
948 assert_eq!(states.strong_count(), 2);
949 let s2_state2 = states
950 .get::<String, (), u64, i64, _, _>(
951 s2,
952 || async { Ok(new_state(s2)) },
953 &Diagnostics::for_tests(),
954 )
955 .await
956 .expect("should successfully initialize");
957 assert_same(&s2_state1, &s2_state2);
958
959 drop(s1_state1);
962 assert_eq!(states.strong_count(), 2);
963 drop(s1_state2);
964 assert_eq!(states.strong_count(), 1);
965 assert_eq!(states.initialized_count(), 2);
966 assert_none!(states.get_cached(&s1));
967
968 let s1_state1 = states
970 .get::<(), (), u64, i64, _, _>(
971 s1,
972 || async { Ok(new_state(s1)) },
973 &Diagnostics::for_tests(),
974 )
975 .await
976 .expect("should successfully initialize");
977 assert_eq!(states.initialized_count(), 2);
978 assert_eq!(states.strong_count(), 2);
979 drop(s1_state1);
980 assert_eq!(states.strong_count(), 1);
981 }
982
983 #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
984 #[cfg_attr(miri, ignore)] async fn state_cache_concurrency() {
986 mz_ore::test::init_logging();
987
988 const COUNT: usize = 1000;
989 let id = ShardId::new();
990 let cache = StateCache::new_no_metrics();
991 let diagnostics = Diagnostics::for_tests();
992
993 let mut futures = (0..COUNT)
994 .map(|_| {
995 cache.get::<(), (), u64, i64, _, _>(
996 id,
997 || async {
998 Ok(TypedState::new(
999 DUMMY_BUILD_INFO.semver_version(),
1000 id,
1001 "host".into(),
1002 0,
1003 ))
1004 },
1005 &diagnostics,
1006 )
1007 })
1008 .collect::<FuturesUnordered<_>>();
1009
1010 for _ in 0..COUNT {
1011 let _ = futures.next().await.unwrap();
1012 }
1013 }
1014}