1use std::any::Any;
13use std::collections::BTreeMap;
14use std::collections::btree_map::Entry;
15use std::fmt::Debug;
16use std::future::Future;
17use std::sync::{Arc, RwLock, TryLockError, Weak};
18use std::time::{Duration, Instant};
19
20use differential_dataflow::difference::Monoid;
21use differential_dataflow::lattice::Lattice;
22use mz_dyncfg::Config;
23use mz_ore::instrument;
24use mz_ore::metrics::MetricsRegistry;
25use mz_ore::task::{AbortOnDropHandle, JoinHandle};
26use mz_ore::url::SensitiveUrl;
27use mz_persist::cfg::{BlobConfig, ConsensusConfig};
28use mz_persist::location::{
29 BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
30 VersionedData,
31};
32use mz_persist_types::{Codec, Codec64};
33use timely::progress::Timestamp;
34use tokio::sync::{Mutex, OnceCell};
35use tracing::debug;
36
37use crate::async_runtime::IsolatedRuntime;
38use crate::error::{CodecConcreteType, CodecMismatch};
39use crate::internal::cache::BlobMemCache;
40use crate::internal::machine::retry_external;
41use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
42use crate::internal::state::TypedState;
43use crate::internal::watch::{AwaitableState, StateWatchNotifier};
44use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
45use crate::schema::SchemaCacheMaps;
46use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};
47
/// A cache of [PersistClient]s and their heavyweight dependencies (blob and
/// consensus handles, metrics, shard-state cache), keyed by external URI so
/// that clients opened against the same location share resources.
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs shared by all clients created through this cache.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    /// Opened blob handles by URI, each paired with the rtt-latency
    /// measurement task whose lifetime is tied to the cache entry.
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    /// Opened consensus handles by URI, analogous to `blob_by_uri`.
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    /// Process-wide cache of per-shard state, shared by all clients.
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    /// Keeps the pubsub receiver task alive for the lifetime of this cache.
    _pubsub_receiver_task: JoinHandle<()>,
}
68
/// Wrapper holding an rtt-latency measurement task's handle; the
/// `AbortOnDropHandle` aborts the task when the owning cache entry is dropped.
#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);
71
impl PersistClientCache {
    /// Returns a new [PersistClientCache].
    ///
    /// The `pubsub` closure is handed the config and metrics so the caller can
    /// decide which [PubSubClientConnection] implementation to construct.
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        // Wire the pubsub receiver into the state cache so diffs pushed by
        // other processes get applied to our cached shard states.
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A convenience constructor for tests: test config, throwaway metrics
    /// registry, and a no-op pubsub connection.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    /// A constructor for use under the turmoil simulator: no-op pubsub, an
    /// immediately-completing receiver task, and a disabled isolated runtime.
    #[cfg(feature = "turmoil")]
    pub fn new_for_turmoil() -> Self {
        use crate::rpc::NoopPubSubSender;

        let cfg = PersistConfig::new_for_tests();
        let metrics = Arc::new(Metrics::new(&cfg, &MetricsRegistry::new()));

        let pubsub_sender: Arc<dyn PubSubSender> = Arc::new(NoopPubSubSender);
        // There is no real receiver; spawn a task that finishes right away.
        let _pubsub_receiver_task = mz_ore::task::spawn(|| "noop", async {});

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_sender),
        ));
        let isolated_runtime = IsolatedRuntime::new_disabled();

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender,
            _pubsub_receiver_task,
        }
    }

    /// Returns the [PersistConfig] shared by all clients in this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns the [Metrics] shared by all clients in this cache.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns the [ShardMetrics] for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Swaps in a fresh, empty [StateCache], dropping all cached shard state.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }

    /// Returns a new [PersistClient] for the given [PersistLocation], reusing
    /// cached blob/consensus handles when the URIs were opened before.
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // How often the rtt-latency tasks probe blob/consensus; named for the
    // expected Prometheus scrape cadence — TODO confirm the actual cadence.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);

    /// Returns the cached consensus handle for `consensus_uri`, opening (and
    /// caching) a new one — plus its rtt-latency task — on first use.
    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                    Arc::clone(&self.cfg().configs),
                )?;
                // Retry transient errors on open; `retry_external` records
                // the attempts in metrics.
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // Storing the task alongside the handle ties the measurement
                // loop's lifetime to the cache entry's.
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    /// Returns the cached blob handle for `blob_uri`, opening (and caching) a
    /// new one — plus its rtt-latency task — on first use.
    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // Note: the rtt task measures the uncached blob; the cached
                // wrapper below is what clients actually use.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}
269
270#[allow(clippy::unused_async)]
284async fn blob_rtt_latency_task(
285 blob: Arc<Tasked<MetricsBlob>>,
286 metrics: Arc<Metrics>,
287 measurement_interval: Duration,
288) -> JoinHandle<()> {
289 mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
290 let mut next_measurement = tokio::time::Instant::now();
293 loop {
294 tokio::time::sleep_until(next_measurement).await;
295 let start = Instant::now();
296 match blob.get(BLOB_GET_LIVENESS_KEY).await {
297 Ok(_) => {
298 metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
299 }
300 Err(_) => {
301 }
305 }
306 next_measurement = tokio::time::Instant::now() + measurement_interval;
307 }
308 })
309}
310
311#[allow(clippy::unused_async)]
325async fn consensus_rtt_latency_task(
326 consensus: Arc<Tasked<MetricsConsensus>>,
327 metrics: Arc<Metrics>,
328 measurement_interval: Duration,
329) -> JoinHandle<()> {
330 mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
331 let mut next_measurement = tokio::time::Instant::now();
334 loop {
335 tokio::time::sleep_until(next_measurement).await;
336 let start = Instant::now();
337 match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
338 Ok(_) => {
339 metrics
340 .consensus
341 .rtt_latency
342 .set(start.elapsed().as_secs_f64());
343 }
344 Err(_) => {
345 }
349 }
350 next_measurement = tokio::time::Instant::now() + measurement_interval;
351 }
352 })
353}
354
/// A type-erased [LockingTypedState], allowing the [StateCache] to hold
/// states of heterogeneous key/value/time/diff types in a single map.
pub(crate) trait DynState: Debug + Send + Sync {
    /// Returns the codec names (and concrete Rust types) this state was
    /// initialized with, used for codec-mismatch error reporting.
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    /// Upcasts to `Any` so callers can downcast back to the typed state.
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    /// Applies an encoded state diff (received via pubsub) to this state.
    fn push_diff(&self, diff: VersionedData);
}
360
impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            // Applying diffs can only move the state forward.
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                // The diff applied and advanced our state.
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, state.seqno
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                // The diff didn't apply: either we already had it (stale) or
                // it's ahead of our state (arrived out of order).
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}
411
/// A process-wide cache of per-shard state, keyed by [ShardId].
///
/// Entries hold `Weak` references, so a shard's state is dropped once no
/// client uses it; the `OnceCell` ensures concurrent callers initialize a
/// given shard at most once.
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}
429
/// The outcome of a [StateCache] map lookup: either a live cached state, or
/// the `OnceCell` that the caller must (race to) initialize.
#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}
435
impl StateCache {
    /// Returns a new, empty [StateCache].
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    /// A test convenience: test config, throwaway metrics, no-op pubsub.
    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

    /// Returns the cached state for `shard_id`, running `init_fn` to create
    /// it if it isn't cached (or was dropped since), and erroring if the
    /// cached state was initialized with different codecs.
    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Monoid + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            // Look up (or create) this shard's cell while holding the map
            // lock, but do any initialization outside of it.
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // The cached state has since been dropped; swap
                            // in a fresh OnceCell so it can be re-initialized.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            // Only a Weak goes into the cell; stash the
                            // strong ref in `did_init` so the winner can
                            // return it directly.
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We won the init race; no downcast needed since we
                        // built the state with exactly these types.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Someone else initialized the cell, but their state
                        // was already dropped; start over.
                        continue;
                    };
                    state
                }
            };

            // The cached state might have been initialized with different
            // codecs; the downcast detects that.
            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

    /// Returns a weak handle to the cached state for `shard_id`, if any is
    /// present (it may already be dangling).
    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    /// Test helper: the live cached state for `shard_id`, if any.
    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    /// Test helper: how many entries have been initialized (live or not).
    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    /// Test helper: how many entries still have live (upgradeable) states.
    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}
596
/// A [TypedState] wrapped in a lock, plus the per-shard bookkeeping (metrics,
/// watch notifier, pubsub subscription) shared by all users of the shard in
/// this process.
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    // Serializes state updates via `lease_for_update`; holds the current
    // lease's expiry time when a lease is outstanding.
    update_semaphore: AwaitableState<Option<tokio::time::Instant>>,
    // Type-erased Arc<SchemaCacheMaps<K, V>>; recovered in `schema_cache`.
    schema_cache: Arc<dyn Any + Send + Sync>,
    // Keeps the pubsub subscription for this shard alive as long as the
    // state is.
    _subscription_token: Arc<ShardSubscriptionToken>,
}
613
614impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
615 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616 let LockingTypedState {
617 shard_id,
618 state,
619 notifier,
620 cfg: _cfg,
621 metrics: _metrics,
622 shard_metrics: _shard_metrics,
623 update_semaphore: _,
624 schema_cache: _schema_cache,
625 _subscription_token,
626 } = self;
627 f.debug_struct("LockingTypedState")
628 .field("shard_id", shard_id)
629 .field("state", state)
630 .field("notifier", notifier)
631 .finish()
632 }
633}
634
635impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
636 fn new(
637 shard_id: ShardId,
638 initial_state: TypedState<K, V, T, D>,
639 metrics: Arc<Metrics>,
640 cfg: Arc<PersistConfig>,
641 subscription_token: Arc<ShardSubscriptionToken>,
642 diagnostics: &Diagnostics,
643 ) -> Self {
644 Self {
645 shard_id,
646 notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
647 state: RwLock::new(initial_state),
648 cfg: Arc::clone(&cfg),
649 shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
650 update_semaphore: AwaitableState::new(None),
651 schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
652 metrics,
653 _subscription_token: subscription_token,
654 }
655 }
656
657 pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
658 Arc::clone(&self.schema_cache)
659 .downcast::<SchemaCacheMaps<K, V>>()
660 .expect("K and V match")
661 }
662}
663
/// Dyncfg controlling how long a state-update lease (see
/// `LockingTypedState::lease_for_update`) lasts before other commands stop
/// waiting for it; zero disables leasing entirely.
pub(crate) const STATE_UPDATE_LEASE_TIMEOUT: Config<Duration> = Config::new(
    "persist_state_update_lease_timeout",
    Duration::from_secs(1),
    "The amount of time for a command to wait for a previous command to finish before executing. \
    (If zero, commands will not wait for others to complete.) Higher values reduce database contention \
    at the cost of higher worst-case latencies for individual requests.",
);
671
672impl<K, V, T, D> LockingTypedState<K, V, T, D> {
673 pub(crate) fn shard_id(&self) -> &ShardId {
674 &self.shard_id
675 }
676
677 pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
678 &self,
679 metrics: &LockMetrics,
680 mut f: F,
681 ) -> R {
682 metrics.acquire_count.inc();
683 let state = match self.state.try_read() {
684 Ok(x) => x,
685 Err(TryLockError::WouldBlock) => {
686 metrics.blocking_acquire_count.inc();
687 let start = Instant::now();
688 let state = self.state.read().expect("lock poisoned");
689 metrics
690 .blocking_seconds
691 .inc_by(start.elapsed().as_secs_f64());
692 state
693 }
694 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
695 };
696 f(&state)
697 }
698
699 pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
700 &self,
701 metrics: &LockMetrics,
702 f: F,
703 ) -> R {
704 metrics.acquire_count.inc();
705 let mut state = match self.state.try_write() {
706 Ok(x) => x,
707 Err(TryLockError::WouldBlock) => {
708 metrics.blocking_acquire_count.inc();
709 let start = Instant::now();
710 let state = self.state.write().expect("lock poisoned");
711 metrics
712 .blocking_seconds
713 .inc_by(start.elapsed().as_secs_f64());
714 state
715 }
716 Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
717 };
718 let seqno_before = state.seqno;
719 let ret = f(&mut state);
720 let seqno_after = state.seqno;
721 debug_assert!(seqno_after >= seqno_before);
722 if seqno_after > seqno_before {
723 self.notifier.notify(seqno_after);
724 }
725 drop(state);
728 ret
729 }
730
731 pub(crate) async fn lease_for_update(&self) -> impl Drop {
739 use tokio::time::Instant;
740
741 let timeout = STATE_UPDATE_LEASE_TIMEOUT.get(&self.cfg);
742
743 struct DropLease(Option<(AwaitableState<Option<Instant>>, Instant)>);
744
745 impl Drop for DropLease {
746 fn drop(&mut self) {
747 if let Some((state, time)) = self.0.take() {
748 state.maybe_modify(|s| {
750 if s.is_some_and(|t| t == time) {
751 *s.get_mut() = None;
752 }
753 })
754 }
755 }
756 }
757
758 if timeout.is_zero() {
760 return DropLease(None);
761 }
762
763 let timeout_state = self.update_semaphore.clone();
764 loop {
765 let now = tokio::time::Instant::now();
766 let expires_at = now + timeout;
767 let maybe_leased = timeout_state.maybe_modify(|state| {
769 if let Some(other_expires_at) = **state
770 && other_expires_at > now
771 {
772 Err(other_expires_at)
774 } else {
775 *state.get_mut() = Some(expires_at);
776 Ok(())
777 }
778 });
779
780 match maybe_leased {
781 Ok(()) => {
782 break DropLease(Some((timeout_state, expires_at)));
783 }
784 Err(other_expires_at) => {
785 let _ = tokio::time::timeout_at(
790 other_expires_at,
791 timeout_state.wait_while(|s| s.is_some()),
792 )
793 .await;
794 }
795 }
796 }
797 }
798
799 pub(crate) fn notifier(&self) -> &StateWatchNotifier {
800 &self.notifier
801 }
802}
803
804#[cfg(test)]
805mod tests {
806 use std::ops::Deref;
807 use std::pin::pin;
808 use std::str::FromStr;
809 use std::sync::atomic::{AtomicBool, Ordering};
810
811 use super::*;
812 use crate::rpc::NoopPubSubSender;
813 use futures::stream::{FuturesUnordered, StreamExt};
814 use mz_build_info::DUMMY_BUILD_INFO;
815 use mz_ore::task::spawn;
816 use mz_ore::{assert_err, assert_none};
817 use tokio::sync::oneshot;
818
    // Verifies that blob/consensus handles are cached per exact URI: any
    // difference (path, query, credentials, port) yields a new entry.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // First open populates one blob and one consensus entry.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // New blob URI, same consensus URI: only the blob map grows.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Same blob URI, new consensus URI: only the consensus map grows.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query strings and paths are part of the cache key.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // So are credentials and ports.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }
888
    // Exercises StateCache::get: failed/panicked inits don't poison the
    // cache, successful inits are shared, codec mismatches error, and
    // dropped states are re-initialized on the next get.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] async fn state_cache() {
        mz_ore::test::init_logging();
        // Helper: a fresh, empty TypedState for the given shard.
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        // Helper: assert two handles share the same underlying state (by
        // comparing the address of the locked state).
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // A panicking init fn leaves the entry uninitialized (run in a
        // separate task so the panic doesn't kill this test).
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .into_tokio_handle()
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // An init fn returning Err also leaves the entry uninitialized.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // A successful init runs the init fn exactly once and caches.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // A second get with matching codecs reuses the cached state; the
        // init fn is not invoked.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Requesting the same shard with different codecs fails with a
        // descriptive CodecMismatch and doesn't run the init fn.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // A different shard gets its own independent entry.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // Dropping all strong refs leaves a dangling Weak in the cache: the
        // entry stays "initialized" but is no longer retrievable.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // A subsequent get transparently re-initializes the dropped shard.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }
1063
    // Hammers StateCache::get for the same shard from many concurrent
    // futures on a multi-threaded runtime; all must resolve successfully
    // (exercising the OnceCell init race in `get`).
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        // Every concurrent get must complete (Ok or Err would both resolve;
        // unwrap asserts they all initialized without error).
        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }
1095
    // Verifies lease_for_update's mutual exclusion: while one lease is held,
    // other acquirers block; releasing the lease lets exactly one proceed.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] async fn update_semaphore() {
        mz_ore::test::init_logging();

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        // Each future grabs the lease, then holds it until its oneshot fires.
        let mk_future = || {
            let (tx, rx) = oneshot::channel();
            let future = async {
                let lease = state.lease_for_update().await;
                let () = rx.await.unwrap();
                drop(lease);
            };
            (future, tx)
        };

        let (one, _one_tx) = mk_future();
        let (two, _two_tx) = mk_future();
        let (three, three_tx) = mk_future();
        let mut one = pin!(one);
        let mut two = pin!(two);
        let mut three = pin!(three);

        // None of the futures can finish yet: whichever holds the lease is
        // still waiting on its channel, and the others are blocked on the
        // lease. The trailing no-op branch wins.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut two => { unreachable!() }
            _ = &mut three => { unreachable!() }
            _ = async {} => {}
        }

        // Allow `three` to release once it obtains the lease.
        three_tx.send(()).unwrap();

        // `one` still can't finish (its channel never fires), but `three`
        // eventually acquires the lease and completes.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut three => { }
        }
    }
1157
    // Stresses lease expiry: a third of the holders never release (relying
    // on the lease timing out), a third release quickly, and a third hold
    // past the timeout. The well-behaved two-thirds must all complete.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] async fn update_semaphore_stress() {
        mz_ore::test::init_logging();

        const TIMEOUT: Duration = Duration::from_millis(100);
        const COUNT: u64 = 100;

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        persist_config.set_config(&STATE_UPDATE_LEASE_TIMEOUT, TIMEOUT);
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        let mut futures = (0..(COUNT * 3))
            .map(async |i| {
                state.lease_for_update().await;
                match i % 3 {
                    // Never completes; its lease must expire for others.
                    0 => {
                        let () = std::future::pending().await;
                    }
                    // Completes within the lease timeout.
                    1 => {
                        tokio::time::sleep(Duration::from_millis(i)).await;
                    }
                    // Completes, but only after its lease has expired.
                    _ => {
                        tokio::time::sleep(Duration::from_millis(i) + TIMEOUT).await;
                    }
                }
            })
            .collect::<FuturesUnordered<_>>();

        // Only the 2/3 of futures that don't pend forever can finish.
        for _ in 0..(COUNT * 2) {
            futures.next().await.unwrap();
        }
    }
1209}