mz_persist_client/cache.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! A cache of [PersistClient]s indexed by [PersistLocation]s.

use std::any::Any;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, RwLock, TryLockError, Weak};
use std::time::{Duration, Instant};

use differential_dataflow::difference::Semigroup;
use differential_dataflow::lattice::Lattice;
use mz_ore::instrument;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::task::{AbortOnDropHandle, JoinHandle};
use mz_ore::url::SensitiveUrl;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
use mz_persist::location::{
    BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
    VersionedData,
};
use mz_persist_types::{Codec, Codec64};
use timely::progress::Timestamp;
use tokio::sync::{Mutex, OnceCell};
use tracing::debug;

use crate::async_runtime::IsolatedRuntime;
use crate::error::{CodecConcreteType, CodecMismatch};
use crate::internal::cache::BlobMemCache;
use crate::internal::machine::retry_external;
use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
use crate::internal::state::TypedState;
use crate::internal::watch::StateWatchNotifier;
use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
use crate::schema::SchemaCacheMaps;
use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};

/// A cache of [PersistClient]s indexed by [PersistLocation]s.
///
/// There should be at most one of these per process. All production
/// PersistClients should be created through this cache.
///
/// This is because, in production, persist is heavily limited by the number of
/// server-side Postgres/Aurora connections. This cache allows PersistClients to
/// share, for example, these Postgres connections.
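///
/// A minimal construction sketch (this mirrors the test-only
/// [PersistClientCache::new_no_metrics] helper; production code passes its
/// real [PersistConfig], metrics registry, and PubSub connection):
///
/// ```ignore
/// use mz_ore::metrics::MetricsRegistry;
///
/// let cache = PersistClientCache::new(
///     PersistConfig::new_for_tests(),
///     &MetricsRegistry::new(),
///     // A no-op PubSub connection; real deployments wire up a live one.
///     |_cfg, _metrics| PubSubClientConnection::noop(),
/// );
/// ```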
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs for persist.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    _pubsub_receiver_task: JoinHandle<()>,
}

#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);

impl PersistClientCache {
    /// Returns a new [PersistClientCache].
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A test helper that returns a [PersistClientCache] disconnected from
    /// metrics.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    /// Returns the [PersistConfig] being used by this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns persist `Metrics`.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns `ShardMetrics` for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Clears the state cache, allowing for tests with disconnected states.
    ///
    /// Only exposed for testing.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }

    /// Returns a new [PersistClient] for interfacing with persist shards made
    /// durable to the given [PersistLocation].
    ///
    /// The same `location` may be used concurrently from multiple processes.
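    ///
    /// A usage sketch (mirroring the `client_cache` test below, which opens an
    /// in-memory `mem://` location; the URIs here are illustrative):
    ///
    /// ```ignore
    /// let client = cache
    ///     .open(PersistLocation {
    ///         blob_uri: SensitiveUrl::from_str("mem://blob").expect("valid URL"),
    ///         consensus_uri: SensitiveUrl::from_str("mem://consensus").expect("valid URL"),
    ///     })
    ///     .await
    ///     .expect("failed to open location");
    /// ```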
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // No sense in measuring rtt latencies more often than this.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);

    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                )?;
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                    Arc::clone(&self.cfg.configs),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // This is intentionally "outside" (wrapping) MetricsBlob so
                // that we don't include cached responses in blob metrics.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}

/// Starts a task to periodically measure the persist-observed latency to
/// blob.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsBlob before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn blob_rtt_latency_task(
    blob: Arc<Tasked<MetricsBlob>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match blob.get(BLOB_GET_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

/// Starts a task to periodically measure the persist-observed latency to
/// consensus.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsConsensus before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn consensus_rtt_latency_task(
    consensus: Arc<Tasked<MetricsConsensus>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics
                        .consensus
                        .rtt_latency
                        .set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

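/// A type-erased [LockingTypedState], so the [StateCache] can hold states with
/// differing `K, V, T, D` types in a single map.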
pub(crate) trait DynState: Debug + Send + Sync {
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    fn push_diff(&self, diff: VersionedData);
}

impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, state.seqno
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}

/// A cache of `TypedState`, shared between all machines for that shard.
///
/// This is shared between all machines that come out of the same
/// [PersistClientCache], but in production there is one of those per process,
/// so in practice, we have one copy of state per shard per process.
///
/// The mutex contention between commands is not an issue, because if two
/// commands for the same shard are executing concurrently, only one can win
/// anyway; the other will retry. With the mutex, we even get to avoid the retry
/// if the racing commands are on the same process.
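///
/// A sketch of how callers obtain a shard's state through this cache (the
/// `fetch_initial_state` helper is hypothetical; real callers pass an
/// `init_fn` that fetches or creates the shard's initial [TypedState]):
///
/// ```ignore
/// let state: Arc<LockingTypedState<K, V, T, D>> = state_cache
///     .get(shard_id, || fetch_initial_state(shard_id), &diagnostics)
///     .await?;
/// ```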
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}

#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}

impl StateCache {
    /// Returns a new StateCache.
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Semigroup + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // If the Weak has lost the ability to upgrade,
                            // we've dropped the State and it's gone. Clear the
                            // OnceCell and init a new one.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We actually did the init work, don't bother casting back
                        // the type erased and weak version. Additionally, inform
                        // any listeners of this new state.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Race condition. Between when we first checked the
                        // OnceCell and the `get_or_try_init` call, (1) the
                        // initialization finished, (2) the other user dropped
                        // the strong ref, and (3) the Arc noticed it was down
                        // to only weak refs and dropped the value. Nothing we
                        // can do except try again.
                        continue;
                    };
                    state
                }
            };

            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

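    /// Returns a [Weak] handle to the cached state for `shard_id`, if one has
    /// been initialized (the returned handle may no longer be upgradable).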
    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}

/// A locked decorator for TypedState that abstracts out the specific lock implementation used.
/// Guards the private lock with public accessor fns to make locking scopes more explicit and
/// simpler to reason about.
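///
/// A sketch of the intended locking discipline (`lock_metrics` stands in for
/// whichever [LockMetrics] instance the caller accounts this lock under):
///
/// ```ignore
/// // Read under the shared lock, returning a value computed from the state.
/// let seqno = state.read_lock(&lock_metrics, |s| s.seqno);
/// // Mutate under the exclusive lock; watchers are notified if the seqno advances.
/// state.write_lock(&lock_metrics, |s| {
///     s.apply_encoded_diffs(&cfg, &metrics, diffs);
/// });
/// ```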
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    /// A [SchemaCacheMaps<K, V>], but stored as an Any so the `: Codec` bounds
    /// don't propagate to basically every struct in persist.
    schema_cache: Arc<dyn Any + Send + Sync>,
    _subscription_token: Arc<ShardSubscriptionToken>,
}

impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LockingTypedState {
            shard_id,
            state,
            notifier,
            cfg: _cfg,
            metrics: _metrics,
            shard_metrics: _shard_metrics,
            schema_cache: _schema_cache,
            _subscription_token,
        } = self;
        f.debug_struct("LockingTypedState")
            .field("shard_id", shard_id)
            .field("state", state)
            .field("notifier", notifier)
            .finish()
    }
}

impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
    fn new(
        shard_id: ShardId,
        initial_state: TypedState<K, V, T, D>,
        metrics: Arc<Metrics>,
        cfg: Arc<PersistConfig>,
        subscription_token: Arc<ShardSubscriptionToken>,
        diagnostics: &Diagnostics,
    ) -> Self {
        Self {
            shard_id,
            notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
            state: RwLock::new(initial_state),
            cfg: Arc::clone(&cfg),
            shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
            schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
            metrics,
            _subscription_token: subscription_token,
        }
    }

    pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
        Arc::clone(&self.schema_cache)
            .downcast::<SchemaCacheMaps<K, V>>()
            .expect("K and V match")
    }
}

impl<K, V, T, D> LockingTypedState<K, V, T, D> {
    pub(crate) fn shard_id(&self) -> &ShardId {
        &self.shard_id
    }

    pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        mut f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let state = match self.state.try_read() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.read().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
        };
        f(&state)
    }

    pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let mut state = match self.state.try_write() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.write().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state write lock poisoned: {}", err),
        };
        let seqno_before = state.seqno;
        let ret = f(&mut state);
        let seqno_after = state.seqno;
        debug_assert!(seqno_after >= seqno_before);
        if seqno_after > seqno_before {
            self.notifier.notify(seqno_after);
        }
        // For now, make sure to notify while under lock. It's possible to move
        // this out of the lock window, see [StateWatchNotifier::notify].
        drop(state);
        ret
    }

    pub(crate) fn notifier(&self) -> &StateWatchNotifier {
        &self.notifier
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};

    use super::*;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location on an empty cache saves the results.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Opening a location with an already opened consensus reuses it, even
        // if the blob is different.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Ditto the other way.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query params and path matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // User info and port also matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // Panicking during init_fn doesn't initialize an entry in the cache.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Returning an error from init_fn doesn't initialize an entry in the cache.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Initialize one shard.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Trying to initialize it again does no work and returns the same state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Trying to initialize with different types doesn't work.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // We can add a shard of a different type.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // The cache holds weak references to State so we reclaim memory if the
        // shard stops being used.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // But we can re-init that shard if necessary.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }
}
979}