use std::any::Any;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, RwLock, TryLockError, Weak};
use std::time::{Duration, Instant};

use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use mz_dyncfg::Config;
use mz_ore::instrument;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::task::{AbortOnDropHandle, JoinHandle};
use mz_ore::url::SensitiveUrl;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
use mz_persist::location::{
    BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
    VersionedData,
};
use mz_persist_types::{Codec, Codec64};
use timely::progress::Timestamp;
use tokio::sync::{Mutex, OnceCell};
use tracing::debug;

use crate::async_runtime::IsolatedRuntime;
use crate::error::{CodecConcreteType, CodecMismatch};
use crate::internal::cache::BlobMemCache;
use crate::internal::machine::retry_external;
use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
use crate::internal::state::TypedState;
use crate::internal::watch::{AwaitableState, StateWatchNotifier};
use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
use crate::schema::SchemaCacheMaps;
use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};
/// A cache of [PersistClient]s indexed by [PersistLocation]s.
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs for persist.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    _pubsub_receiver_task: JoinHandle<()>,
}

/// Handle to a background task that periodically measures the round-trip
/// latency to an external service; the task is aborted when this is dropped.
#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);
impl PersistClientCache {
    /// Returns a new [PersistClientCache]. The `pubsub` closure is given the
    /// config and metrics and returns the [PubSubClientConnection] to use.
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A test helper: returns a new cache with test defaults and no-op
    /// metrics and pubsub.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    /// A constructor for turmoil-based simulation tests: pubsub is a no-op
    /// and the isolated runtime is disabled, so all work stays on the
    /// simulated runtime.
    #[cfg(feature = "turmoil")]
    pub fn new_for_turmoil() -> Self {
        use crate::rpc::NoopPubSubSender;

        let cfg = PersistConfig::new_for_tests();
        let metrics = Arc::new(Metrics::new(&cfg, &MetricsRegistry::new()));

        let pubsub_sender: Arc<dyn PubSubSender> = Arc::new(NoopPubSubSender);
        let _pubsub_receiver_task = mz_ore::task::spawn(|| "noop", async {});

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_sender),
        ));
        let isolated_runtime = IsolatedRuntime::new_disabled();

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender,
            _pubsub_receiver_task,
        }
    }

    /// Returns the [PersistConfig] used by this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns the [Metrics] shared by all clients in this cache.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns the [ShardMetrics] for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Replaces the state cache with a fresh one, forcing subsequent
    /// operations to re-fetch state. Useful for simulating a restart.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }
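
    /// Opens a [PersistClient] for the given location, reusing the cached
    /// [Blob] and [Consensus] handles if this cache has already opened them.
    ///
    /// A minimal usage sketch (not compiled here; assumes an async context,
    /// with `mem://` URIs as exercised by the tests at the bottom of this
    /// file):
    ///
    /// ```ignore
    /// let cache = PersistClientCache::new_no_metrics();
    /// let client = cache
    ///     .open(PersistLocation {
    ///         blob_uri: SensitiveUrl::from_str("mem://blob").expect("invalid URL"),
    ///         consensus_uri: SensitiveUrl::from_str("mem://consensus").expect("invalid URL"),
    ///     })
    ///     .await
    ///     .expect("failed to open location");
    /// ```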
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // Roughly one measurement per typical Prometheus scrape interval, so each
    // scrape observes a reasonably fresh value.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);
    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock across the open, so concurrent
                // callers don't double-open this consensus.
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                    Arc::clone(&self.cfg().configs),
                )?;
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock across the open, so concurrent
                // callers don't double-open this blob.
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                    Arc::clone(&self.cfg.configs),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // The cache wraps the metrics layer so that cache hits aren't
                // counted as blob gets in the metrics.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}
/// Spawns a background task that periodically measures the persist-observed
/// round-trip latency to [Blob] and records it in `metrics`.
#[allow(clippy::unused_async)]
async fn blob_rtt_latency_task(
    blob: Arc<Tasked<MetricsBlob>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match blob.get(BLOB_GET_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Ignore errors: this task only measures latency, so the
                    // next measurement will simply try again.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

/// Spawns a background task that periodically measures the persist-observed
/// round-trip latency to [Consensus] and records it in `metrics`.
#[allow(clippy::unused_async)]
async fn consensus_rtt_latency_task(
    consensus: Arc<Tasked<MetricsConsensus>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics
                        .consensus
                        .rtt_latency
                        .set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Ignore errors: this task only measures latency, so the
                    // next measurement will simply try again.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}
/// A type-erased [LockingTypedState], as stored in a [StateCache].
pub(crate) trait DynState: Debug + Send + Sync {
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    fn push_diff(&self, diff: VersionedData);
}
impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, seqno_after
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}
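
/// A cache of per-shard state, shared between all handles to a shard in this
/// process and kept up to date via pubsub.
///
/// Each entry is an `OnceCell<Weak<dyn DynState>>`: the `OnceCell`
/// deduplicates concurrent initialization, while the `Weak` ensures the cache
/// never keeps a shard's state alive on its own. A minimal sketch of that
/// idiom (illustrative only, not part of this API):
///
/// ```ignore
/// let cell: OnceCell<Weak<String>> = OnceCell::new();
/// let strong = Arc::new("state".to_string());
/// cell.set(Arc::downgrade(&strong)).expect("first set");
/// assert!(cell.get().expect("set").upgrade().is_some());
/// drop(strong);
/// // The cache entry remains, but the state itself is gone.
/// assert!(cell.get().expect("set").upgrade().is_none());
/// ```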
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}

#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}
impl StateCache {
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }
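
    /// Returns the [LockingTypedState] for `shard_id`, running `init_fn` to
    /// initialize it on first use. Concurrent callers race to initialize;
    /// losers discard their work and share the winner's state.
    ///
    /// A sketch of a call site (illustrative; `new_state` stands in for real
    /// initialization, mirroring the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let state = state_cache
    ///     .get::<(), (), u64, i64, _, _>(
    ///         shard_id,
    ///         || async { Ok(new_state(shard_id)) },
    ///         &Diagnostics::for_tests(),
    ///     )
    ///     .await?;
    /// ```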
    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Monoid + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // The cached state has since been dropped. Swap
                            // in a fresh cell so it can be re-initialized.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We initialized the state ourselves, so it's
                        // guaranteed to have the requested types.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Race: the initializing caller dropped its strong
                        // reference before we could upgrade. Start over.
                        continue;
                    };
                    state
                }
            };

            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }
    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}
/// A locked wrapper around [TypedState] that also carries the per-shard
/// plumbing (metrics, config, pubsub subscription) needed to keep it fresh.
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    update_semaphore: AwaitableState<Option<tokio::time::Instant>>,
    schema_cache: Arc<dyn Any + Send + Sync>,
    /// Keeps the pubsub subscription for this shard alive for as long as the
    /// state is.
    _subscription_token: Arc<ShardSubscriptionToken>,
}
impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LockingTypedState {
            shard_id,
            state,
            notifier,
            cfg: _cfg,
            metrics: _metrics,
            shard_metrics: _shard_metrics,
            update_semaphore: _,
            schema_cache: _schema_cache,
            _subscription_token,
        } = self;
        f.debug_struct("LockingTypedState")
            .field("shard_id", shard_id)
            .field("state", state)
            .field("notifier", notifier)
            .finish()
    }
}
impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
    fn new(
        shard_id: ShardId,
        initial_state: TypedState<K, V, T, D>,
        metrics: Arc<Metrics>,
        cfg: Arc<PersistConfig>,
        subscription_token: Arc<ShardSubscriptionToken>,
        diagnostics: &Diagnostics,
    ) -> Self {
        Self {
            shard_id,
            notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
            state: RwLock::new(initial_state),
            cfg: Arc::clone(&cfg),
            shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
            update_semaphore: AwaitableState::new(None),
            schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
            metrics,
            _subscription_token: subscription_token,
        }
    }

    pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
        Arc::clone(&self.schema_cache)
            .downcast::<SchemaCacheMaps<K, V>>()
            .expect("K and V match")
    }
}
pub(crate) const STATE_UPDATE_LEASE_TIMEOUT: Config<Duration> = Config::new(
    "persist_state_update_lease_timeout",
    Duration::from_secs(1),
    "The amount of time for a command to wait for a previous command to finish before executing. \
    (If zero, commands will not wait for others to complete.) Higher values reduce database contention \
    at the cost of higher worst-case latencies for individual requests.",
);
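
// The timeout can be overridden at runtime through the dyncfg system, as the
// stress test at the bottom of this file does:
//
//     persist_config.set_config(&STATE_UPDATE_LEASE_TIMEOUT, Duration::from_millis(100));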

impl<K, V, T, D> LockingTypedState<K, V, T, D> {
    pub(crate) fn shard_id(&self) -> &ShardId {
        &self.shard_id
    }

    pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        mut f: F,
    ) -> R {
        metrics.acquire_count.inc();
        // Try the fast, non-blocking path first; fall back to a blocking
        // acquire and record how long we were blocked.
        let state = match self.state.try_read() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.read().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
        };
        f(&state)
    }

    pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let mut state = match self.state.try_write() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.write().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state write lock poisoned: {}", err),
        };
        let seqno_before = state.seqno;
        let ret = f(&mut state);
        let seqno_after = state.seqno;
        debug_assert!(seqno_after >= seqno_before);
        if seqno_after > seqno_before {
            self.notifier.notify(seqno_after);
        }
        drop(state);
        ret
    }
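
    /// Grants a lease for applying a state update, expiring after
    /// [STATE_UPDATE_LEASE_TIMEOUT]. At most one lease is live at a time:
    /// later callers wait until the current lease is dropped or its deadline
    /// passes. Dropping the returned guard releases the lease.
    ///
    /// A sketch of the intended usage (illustrative; `apply_update` is a
    /// stand-in, see the `update_semaphore` test below for the real thing):
    ///
    /// ```ignore
    /// let lease = state.lease_for_update().await; // waits for any current holder
    /// apply_update(&state);
    /// drop(lease); // lets the next waiter proceed
    /// ```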
    pub(crate) async fn lease_for_update(&self) -> impl Drop {
        use tokio::time::Instant;

        let timeout = STATE_UPDATE_LEASE_TIMEOUT.get(&self.cfg);

        struct DropLease(Option<(AwaitableState<Option<Instant>>, Instant)>);

        impl Drop for DropLease {
            fn drop(&mut self) {
                if let Some((state, time)) = self.0.take() {
                    state.maybe_modify(|s| {
                        // Only release the lease if it's still the one we
                        // took; a later lease may have replaced an expired
                        // one.
                        if s.is_some_and(|t| t == time) {
                            *s.get_mut() = None;
                        }
                    })
                }
            }
        }

        // A zero timeout disables leasing entirely.
        if timeout.is_zero() {
            return DropLease(None);
        }

        let timeout_state = self.update_semaphore.clone();
        loop {
            let now = tokio::time::Instant::now();
            let expires_at = now + timeout;
            let maybe_leased = timeout_state.maybe_modify(|state| {
                if let Some(other_expires_at) = **state
                    && other_expires_at > now
                {
                    // Someone else holds an unexpired lease.
                    Err(other_expires_at)
                } else {
                    *state.get_mut() = Some(expires_at);
                    Ok(())
                }
            });

            match maybe_leased {
                Ok(()) => {
                    break DropLease(Some((timeout_state, expires_at)));
                }
                Err(other_expires_at) => {
                    // Wait until the current lease is released or its
                    // deadline passes, whichever comes first, then retry.
                    let _ = tokio::time::timeout_at(
                        other_expires_at,
                        timeout_state.wait_while(|s| s.is_some()),
                    )
                    .await;
                }
            }
        }
    }

    pub(crate) fn notifier(&self) -> &StateWatchNotifier {
        &self.notifier
    }
}
#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::pin::pin;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::rpc::NoopPubSubSender;
    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};
    use tokio::sync::oneshot;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location populates both caches.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // A new blob URI with a cached consensus URI only adds a blob entry.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // A cached blob URI with a new consensus URI only adds a consensus entry.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query strings and paths make a URI distinct.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // Userinfo and ports make a URI distinct, too.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // A panicking init_fn should not initialize the cache entry.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .into_tokio_handle()
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // A failing init_fn should not initialize the cache entry either.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // A successful init_fn initializes the cache entry.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Getting a cached state again skips init_fn entirely.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Requesting mismatched codecs fails without running init_fn.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // A different shard gets an independent entry.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // Dropping the last strong reference releases the cached state.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // The shard can be re-initialized after its state was dropped.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn update_semaphore() {
        mz_ore::test::init_logging();

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        // Each future takes the lease, then parks until its oneshot fires.
        let mk_future = || {
            let (tx, rx) = oneshot::channel();
            let future = async {
                let lease = state.lease_for_update().await;
                let () = rx.await.unwrap();
                drop(lease);
            };
            (future, tx)
        };

        let (one, _one_tx) = mk_future();
        let (two, _two_tx) = mk_future();
        let (three, three_tx) = mk_future();
        let mut one = pin!(one);
        let mut two = pin!(two);
        let mut three = pin!(three);

        // Poll all three once: the first poll takes the lease, the others
        // block on it, and none can complete yet.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut two => { unreachable!() }
            _ = &mut three => { unreachable!() }
            _ = async {} => {}
        }

        three_tx.send(()).unwrap();

        // `three` can finish once the first lease times out; `one` is still
        // parked on its channel.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut three => { }
        }
    }
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn update_semaphore_stress() {
        mz_ore::test::init_logging();

        const TIMEOUT: Duration = Duration::from_millis(100);
        const COUNT: u64 = 100;

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        persist_config.set_config(&STATE_UPDATE_LEASE_TIMEOUT, TIMEOUT);
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        let mut futures = (0..(COUNT * 3))
            .map(async |i| {
                state.lease_for_update().await;
                // A third of the futures never complete, a third complete
                // within the lease timeout, and a third take longer than it.
                match i % 3 {
                    0 => {
                        let () = std::future::pending().await;
                    }
                    1 => {
                        tokio::time::sleep(Duration::from_millis(i)).await;
                    }
                    _ => {
                        tokio::time::sleep(Duration::from_millis(i) + TIMEOUT).await;
                    }
                }
            })
            .collect::<FuturesUnordered<_>>();

        // Only two-thirds of the futures ever complete; the pending third
        // must not deadlock the rest.
        for _ in 0..(COUNT * 2) {
            let _ = futures.next().await.unwrap();
        }
    }
}