mz_persist_client/cache.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! A cache of [PersistClient]s indexed by [PersistLocation]s.

use std::any::Any;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, RwLock, TryLockError, Weak};
use std::time::{Duration, Instant};

use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use mz_ore::instrument;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::task::{AbortOnDropHandle, JoinHandle};
use mz_ore::url::SensitiveUrl;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
use mz_persist::location::{
    BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
    VersionedData,
};
use mz_persist_types::{Codec, Codec64};
use timely::progress::Timestamp;
use tokio::sync::{Mutex, OnceCell};
use tracing::debug;

use crate::async_runtime::IsolatedRuntime;
use crate::error::{CodecConcreteType, CodecMismatch};
use crate::internal::cache::BlobMemCache;
use crate::internal::machine::retry_external;
use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
use crate::internal::state::TypedState;
use crate::internal::watch::StateWatchNotifier;
use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
use crate::schema::SchemaCacheMaps;
use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};

/// A cache of [PersistClient]s indexed by [PersistLocation]s.
///
/// There should be at most one of these per process. All production
/// PersistClients should be created through this cache.
///
/// This is because, in production, persist is heavily limited by the number of
/// server-side Postgres/Aurora connections. This cache allows PersistClients to
/// share, for example, these Postgres connections.
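///
/// A minimal usage sketch (not compiled as a doctest), using an in-memory
/// [PersistLocation] like the tests at the bottom of this file; production
/// callers would pass real blob/consensus URIs and a shared [MetricsRegistry]:
///
/// ```ignore
/// let cache = PersistClientCache::new_no_metrics();
/// let client = cache
///     .open(PersistLocation {
///         blob_uri: SensitiveUrl::from_str("mem://blob")?,
///         consensus_uri: SensitiveUrl::from_str("mem://consensus")?,
///     })
///     .await?;
/// ```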
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs for persist.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    _pubsub_receiver_task: JoinHandle<()>,
}

#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);

impl PersistClientCache {
    /// Returns a new [PersistClientCache].
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A test helper that returns a [PersistClientCache] disconnected from
    /// metrics.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    #[cfg(feature = "turmoil")]
    /// Create a [PersistClientCache] for use in turmoil tests.
    ///
    /// Turmoil wants to run all software under test in a single thread, so we disable the
    /// (multi-threaded) isolated runtime.
    pub fn new_for_turmoil() -> Self {
        use crate::rpc::NoopPubSubSender;

        let cfg = PersistConfig::new_for_tests();
        let metrics = Arc::new(Metrics::new(&cfg, &MetricsRegistry::new()));

        let pubsub_sender: Arc<dyn PubSubSender> = Arc::new(NoopPubSubSender);
        let _pubsub_receiver_task = mz_ore::task::spawn(|| "noop", async {});

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_sender),
        ));
        let isolated_runtime = IsolatedRuntime::new_disabled();

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender,
            _pubsub_receiver_task,
        }
    }

    /// Returns the [PersistConfig] being used by this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns persist `Metrics`.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns `ShardMetrics` for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Clears the state cache, allowing for tests with disconnected states.
    ///
    /// Only exposed for testing.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }

    /// Returns a new [PersistClient] for interfacing with persist shards made
    /// durable to the given [PersistLocation].
    ///
    /// The same `location` may be used concurrently from multiple processes.
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // No sense in measuring rtt latencies more often than this.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);

    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                    Arc::clone(&self.cfg().configs),
                )?;
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                    Arc::clone(&self.cfg.configs),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // This is intentionally "outside" (wrapping) MetricsBlob so
                // that we don't include cached responses in blob metrics.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}

/// Starts a task to periodically measure the persist-observed latency to
/// blob.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsBlob before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn blob_rtt_latency_task(
    blob: Arc<Tasked<MetricsBlob>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match blob.get(BLOB_GET_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

/// Starts a task to periodically measure the persist-observed latency to
/// consensus.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsConsensus before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn consensus_rtt_latency_task(
    consensus: Arc<Tasked<MetricsConsensus>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics
                        .consensus
                        .rtt_latency
                        .set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

pub(crate) trait DynState: Debug + Send + Sync {
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    fn push_diff(&self, diff: VersionedData);
}

impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, state.seqno
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}

/// A cache of `TypedState`, shared between all machines for that shard.
///
/// This is shared between all machines that come out of the same
/// [PersistClientCache], but in production there is one of those per process,
/// so in practice, we have one copy of state per shard per process.
///
/// The mutex contention between commands is not an issue, because if two
/// commands for the same shard are executing concurrently, only one can win
/// anyway; the other will retry. With the mutex, we even get to avoid the retry
/// if the racing commands are on the same process.
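///
/// A sketch of the sharing behavior (mirroring the tests at the bottom of this
/// file); `shard`, `diag`, and `new_state` are placeholders for a [ShardId], a
/// [Diagnostics], and an init helper like the one in those tests:
///
/// ```ignore
/// let states = StateCache::new_no_metrics();
/// let s1 = states
///     .get::<(), (), u64, i64, _, _>(shard, || async { Ok(new_state(shard)) }, &diag)
///     .await?;
/// let s2 = states
///     .get::<(), (), u64, i64, _, _>(shard, || async { Ok(new_state(shard)) }, &diag)
///     .await?;
/// // Both handles point at the same shared state; the second init_fn never ran.
/// // Entries are held weakly, so dropping s1 and s2 lets the state be reclaimed.
/// ```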
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}

#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}

impl StateCache {
    /// Returns a new StateCache.
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Monoid + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // If the Weak has lost the ability to upgrade,
                            // we've dropped the State and it's gone. Clear the
                            // OnceCell and init a new one.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We actually did the init work, don't bother casting back
                        // the type erased and weak version. Additionally, inform
                        // any listeners of this new state.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Race condition. Between when we first checked the
                        // OnceCell and the `get_or_try_init` call, (1) the
                        // initialization finished, (2) the other user dropped
                        // the strong ref, and (3) the Arc noticed it was down
                        // to only weak refs and dropped the value. Nothing we
                        // can do except try again.
                        continue;
                    };
                    state
                }
            };

            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}

/// A locked decorator for TypedState that abstracts out the specific lock implementation used.
/// Guards the private lock with public accessor fns to make locking scopes more explicit and
/// simpler to reason about.
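///
/// A minimal sketch of the intended calling pattern (`read_metrics` and
/// `write_metrics` stand in for whichever [LockMetrics] the caller tracks;
/// this file itself uses `metrics.locks.applier_write` in `push_diff`):
///
/// ```ignore
/// // Both accessors take a closure, so a lock guard can never escape its scope.
/// let seqno = state.read_lock(&read_metrics, |s| s.seqno);
/// state.write_lock(&write_metrics, |s| {
///     // mutate `s` here; watchers are notified if `s.seqno` advanced
/// });
/// ```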
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    /// A [SchemaCacheMaps<K, V>], but stored as an Any so the `: Codec` bounds
    /// don't propagate to basically every struct in persist.
    schema_cache: Arc<dyn Any + Send + Sync>,
    _subscription_token: Arc<ShardSubscriptionToken>,
}

impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LockingTypedState {
            shard_id,
            state,
            notifier,
            cfg: _cfg,
            metrics: _metrics,
            shard_metrics: _shard_metrics,
            schema_cache: _schema_cache,
            _subscription_token,
        } = self;
        f.debug_struct("LockingTypedState")
            .field("shard_id", shard_id)
            .field("state", state)
            .field("notifier", notifier)
            .finish()
    }
}

impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
    fn new(
        shard_id: ShardId,
        initial_state: TypedState<K, V, T, D>,
        metrics: Arc<Metrics>,
        cfg: Arc<PersistConfig>,
        subscription_token: Arc<ShardSubscriptionToken>,
        diagnostics: &Diagnostics,
    ) -> Self {
        Self {
            shard_id,
            notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
            state: RwLock::new(initial_state),
            cfg: Arc::clone(&cfg),
            shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
            schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
            metrics,
            _subscription_token: subscription_token,
        }
    }

    pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
        Arc::clone(&self.schema_cache)
            .downcast::<SchemaCacheMaps<K, V>>()
            .expect("K and V match")
    }
}

impl<K, V, T, D> LockingTypedState<K, V, T, D> {
    pub(crate) fn shard_id(&self) -> &ShardId {
        &self.shard_id
    }

    pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        mut f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let state = match self.state.try_read() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.read().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
        };
        f(&state)
    }

    pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let mut state = match self.state.try_write() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.write().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state write lock poisoned: {}", err),
        };
        let seqno_before = state.seqno;
        let ret = f(&mut state);
        let seqno_after = state.seqno;
        debug_assert!(seqno_after >= seqno_before);
        if seqno_after > seqno_before {
            self.notifier.notify(seqno_after);
        }
        // For now, make sure to notify while under lock. It's possible to move
        // this out of the lock window, see [StateWatchNotifier::notify].
        drop(state);
        ret
    }

    pub(crate) fn notifier(&self) -> &StateWatchNotifier {
        &self.notifier
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};

    use super::*;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location on an empty cache saves the results.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Opening a location with an already opened consensus reuses it, even
        // if the blob is different.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Ditto the other way.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query params and path matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // User info and port also matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // Panicking during init_fn doesn't initialize an entry in the cache.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .into_tokio_handle()
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Returning an error from init_fn doesn't initialize an entry in the cache.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Initialize one shard.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Trying to initialize it again does no work and returns the same state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Trying to initialize with different types doesn't work.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // We can add a shard of a different type.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // The cache holds weak references to State so we reclaim memory if the
        // shard stops being used.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // But we can re-init that shard if necessary.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }
}