// mz_persist_client/cache.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! A cache of [PersistClient]s indexed by [PersistLocation]s.

use std::any::Any;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, RwLock, TryLockError, Weak};
use std::time::{Duration, Instant};

use differential_dataflow::difference::Semigroup;
use differential_dataflow::lattice::Lattice;
use mz_ore::instrument;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::task::{AbortOnDropHandle, JoinHandle};
use mz_ore::url::SensitiveUrl;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
use mz_persist::location::{
    BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
    VersionedData,
};
use mz_persist_types::{Codec, Codec64};
use timely::progress::Timestamp;
use tokio::sync::{Mutex, OnceCell};
use tracing::debug;

use crate::async_runtime::IsolatedRuntime;
use crate::error::{CodecConcreteType, CodecMismatch};
use crate::internal::cache::BlobMemCache;
use crate::internal::machine::retry_external;
use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
use crate::internal::state::TypedState;
use crate::internal::watch::StateWatchNotifier;
use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
use crate::schema::SchemaCacheMaps;
use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};

/// A cache of [PersistClient]s indexed by [PersistLocation]s.
///
/// There should be at most one of these per process. All production
/// PersistClients should be created through this cache.
///
/// This is because, in production, persist is heavily limited by the number of
/// server-side Postgres/Aurora connections. This cache allows PersistClients to
/// share, for example, these Postgres connections.
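///
/// A minimal construction sketch (mirroring this file's tests, which use a
/// noop PubSub connection; production callers wire in a real one):
///
/// ```ignore
/// let cache = PersistClientCache::new(
///     PersistConfig::new_for_tests(),
///     &MetricsRegistry::new(),
///     |_cfg, _metrics| PubSubClientConnection::noop(),
/// );
/// ```
///
/// Individual [PersistClient]s are then obtained via [PersistClientCache::open].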
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs for persist.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    _pubsub_receiver_task: JoinHandle<()>,
}

#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);

impl PersistClientCache {
    /// Returns a new [PersistClientCache].
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A test helper that returns a [PersistClientCache] disconnected from
    /// metrics.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    /// Returns the [PersistConfig] being used by this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns persist `Metrics`.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns `ShardMetrics` for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Clears the state cache, allowing for tests with disconnected states.
    ///
    /// Only exposed for testing.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }

    /// Returns a new [PersistClient] for interfacing with persist shards made
    /// durable to the given [PersistLocation].
    ///
    /// The same `location` may be used concurrently from multiple processes.
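    ///
    /// A sketch of opening a client against an in-memory location, as the
    /// tests below do (`cache` here is a [PersistClientCache]; the `mem://`
    /// URIs are test-only, production locations point at real blob/consensus
    /// backends):
    ///
    /// ```ignore
    /// let client = cache
    ///     .open(PersistLocation {
    ///         blob_uri: SensitiveUrl::from_str("mem://blob")?,
    ///         consensus_uri: SensitiveUrl::from_str("mem://consensus")?,
    ///     })
    ///     .await?;
    /// ```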
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // No sense in measuring rtt latencies more often than this.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);

    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                    Arc::clone(&self.cfg().configs),
                )?;
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                    Arc::clone(&self.cfg.configs),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // This is intentionally "outside" (wrapping) MetricsBlob so
                // that we don't include cached responses in blob metrics.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}

/// Starts a task to periodically measure the persist-observed latency to
/// blob.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsBlob before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn blob_rtt_latency_task(
    blob: Arc<Tasked<MetricsBlob>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match blob.get(BLOB_GET_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

/// Starts a task to periodically measure the persist-observed latency to
/// consensus.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsConsensus before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn consensus_rtt_latency_task(
    consensus: Arc<Tasked<MetricsConsensus>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics
                        .consensus
                        .rtt_latency
                        .set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

pub(crate) trait DynState: Debug + Send + Sync {
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    fn push_diff(&self, diff: VersionedData);
}

impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, state.seqno
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}

/// A cache of `TypedState`, shared between all machines for that shard.
///
/// This is shared between all machines that come out of the same
/// [PersistClientCache], but in production there is one of those per process,
/// so in practice, we have one copy of state per shard per process.
///
/// The mutex contention between commands is not an issue, because if two
/// commands for the same shard are executing concurrently, only one can win
/// anyway; the other will retry. With the mutex, we even get to avoid the
/// retry if the racing commands are on the same process.
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}

#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}

impl StateCache {
    /// Returns a new StateCache.
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

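    /// Returns the cached state for the given shard, initializing it with
    /// `init_fn` on first use (or again once all strong references to a
    /// previous initialization have been dropped).
    ///
    /// A minimal sketch of a call site, assuming the `()`/u64/i64 codecs used
    /// by this file's tests (`new_state` is the test-only helper defined in
    /// the tests below):
    ///
    /// ```ignore
    /// let state = state_cache
    ///     .get::<(), (), u64, i64, _, _>(
    ///         shard_id,
    ///         || async { Ok(new_state(shard_id)) },
    ///         &Diagnostics::for_tests(),
    ///     )
    ///     .await?;
    /// ```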
    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Semigroup + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // If the Weak has lost the ability to upgrade,
                            // we've dropped the State and it's gone. Clear the
                            // OnceCell and init a new one.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We actually did the init work; don't bother casting
                        // back the type-erased and weak version. Additionally,
                        // inform any listeners of this new state.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Race condition. Between when we first checked the
                        // OnceCell and the `get_or_try_init` call, (1) the
                        // initialization finished, (2) the other user dropped
                        // the strong ref, and (3) the Arc noticed it was down
                        // to only weak refs and dropped the value. Nothing we
                        // can do except try again.
                        continue;
                    };
                    state
                }
            };

            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}

/// A locked decorator for TypedState that abstracts out the specific lock implementation used.
/// Guards the private lock with public accessor fns to make locking scopes more explicit and
/// simpler to reason about.
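///
/// A minimal usage sketch, modeled on `push_diff` above (`applier_write` is a
/// real [LockMetrics] instance; `read_metrics` below is a stand-in for
/// whichever [LockMetrics] the caller tracks reads under):
///
/// ```ignore
/// // Mutate the state under the write lock; watchers registered with the
/// // notifier are woken if the seqno advances.
/// state.write_lock(&metrics.locks.applier_write, |s| {
///     s.apply_encoded_diffs(&cfg, &metrics, std::iter::once(&diff));
/// });
///
/// // Compute a value from the state under the read lock.
/// let seqno = state.read_lock(&read_metrics, |s| s.seqno);
/// ```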
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    /// A [SchemaCacheMaps<K, V>], but stored as an Any so the `: Codec` bounds
    /// don't propagate to basically every struct in persist.
    schema_cache: Arc<dyn Any + Send + Sync>,
    _subscription_token: Arc<ShardSubscriptionToken>,
}

impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LockingTypedState {
            shard_id,
            state,
            notifier,
            cfg: _cfg,
            metrics: _metrics,
            shard_metrics: _shard_metrics,
            schema_cache: _schema_cache,
            _subscription_token,
        } = self;
        f.debug_struct("LockingTypedState")
            .field("shard_id", shard_id)
            .field("state", state)
            .field("notifier", notifier)
            .finish()
    }
}

impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
    fn new(
        shard_id: ShardId,
        initial_state: TypedState<K, V, T, D>,
        metrics: Arc<Metrics>,
        cfg: Arc<PersistConfig>,
        subscription_token: Arc<ShardSubscriptionToken>,
        diagnostics: &Diagnostics,
    ) -> Self {
        Self {
            shard_id,
            notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
            state: RwLock::new(initial_state),
            cfg: Arc::clone(&cfg),
            shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
            schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
            metrics,
            _subscription_token: subscription_token,
        }
    }

    pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
        Arc::clone(&self.schema_cache)
            .downcast::<SchemaCacheMaps<K, V>>()
            .expect("K and V match")
    }
}

impl<K, V, T, D> LockingTypedState<K, V, T, D> {
    pub(crate) fn shard_id(&self) -> &ShardId {
        &self.shard_id
    }

    pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        mut f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let state = match self.state.try_read() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.read().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
        };
        f(&state)
    }

    pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let mut state = match self.state.try_write() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.write().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state write lock poisoned: {}", err),
        };
        let seqno_before = state.seqno;
        let ret = f(&mut state);
        let seqno_after = state.seqno;
        debug_assert!(seqno_after >= seqno_before);
        if seqno_after > seqno_before {
            self.notifier.notify(seqno_after);
        }
        // For now, make sure to notify while under lock. It's possible to move
        // this out of the lock window, see [StateWatchNotifier::notify].
        drop(state);
        ret
    }

    pub(crate) fn notifier(&self) -> &StateWatchNotifier {
        &self.notifier
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};

    use super::*;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location on an empty cache saves the results.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Opening a location with an already opened consensus reuses it, even
        // if the blob is different.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Ditto the other way.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query params and path matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // User info and port also matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // Panicking during init_fn doesn't initialize an entry in the cache.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Returning an error from init_fn doesn't initialize an entry in the cache.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Initialize one shard.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Trying to initialize it again does no work and returns the same state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Trying to initialize with different types doesn't work.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // We can add a shard of a different type.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // The cache holds weak references to State so we reclaim memory if the
        // shard stops being used.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // But we can re-init that shard if necessary.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }
}