mz_persist_client/cache.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! A cache of [PersistClient]s indexed by [PersistLocation]s.

use std::any::Any;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, RwLock, TryLockError, Weak};
use std::time::{Duration, Instant};

use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use mz_dyncfg::Config;
use mz_ore::instrument;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::task::{AbortOnDropHandle, JoinHandle};
use mz_ore::url::SensitiveUrl;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
use mz_persist::location::{
    BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
    VersionedData,
};
use mz_persist_types::{Codec, Codec64};
use timely::progress::Timestamp;
use tokio::sync::{Mutex, OnceCell};
use tracing::debug;

use crate::async_runtime::IsolatedRuntime;
use crate::error::{CodecConcreteType, CodecMismatch};
use crate::internal::cache::BlobMemCache;
use crate::internal::machine::retry_external;
use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
use crate::internal::state::TypedState;
use crate::internal::watch::{AwaitableState, StateWatchNotifier};
use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
use crate::schema::SchemaCacheMaps;
use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};

/// A cache of [PersistClient]s indexed by [PersistLocation]s.
///
/// There should be at most one of these per process. All production
/// PersistClients should be created through this cache.
///
/// This is because, in production, persist is heavily limited by the number of
/// server-side Postgres/Aurora connections. This cache allows PersistClients to
/// share, for example, these Postgres connections.
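///
/// A minimal usage sketch (hedged: the in-memory URIs are illustrative, as
/// used by the tests in this file, not a production configuration):
///
/// ```ignore
/// let cache = PersistClientCache::new(
///     PersistConfig::new_for_tests(),
///     &MetricsRegistry::new(),
///     |_, _| PubSubClientConnection::noop(),
/// );
/// let client = cache
///     .open(PersistLocation {
///         blob_uri: SensitiveUrl::from_str("mem://blob")?,
///         consensus_uri: SensitiveUrl::from_str("mem://consensus")?,
///     })
///     .await?;
/// ```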
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs for persist.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    _pubsub_receiver_task: JoinHandle<()>,
}

#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);

impl PersistClientCache {
    /// Returns a new [PersistClientCache].
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A test helper that returns a [PersistClientCache] disconnected from
    /// metrics.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    #[cfg(feature = "turmoil")]
    /// Create a [PersistClientCache] for use in turmoil tests.
    ///
    /// Turmoil wants to run all software under test in a single thread, so we disable the
    /// (multi-threaded) isolated runtime.
    pub fn new_for_turmoil() -> Self {
        use crate::rpc::NoopPubSubSender;

        let cfg = PersistConfig::new_for_tests();
        let metrics = Arc::new(Metrics::new(&cfg, &MetricsRegistry::new()));

        let pubsub_sender: Arc<dyn PubSubSender> = Arc::new(NoopPubSubSender);
        let _pubsub_receiver_task = mz_ore::task::spawn(|| "noop", async {});

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_sender),
        ));
        let isolated_runtime = IsolatedRuntime::new_disabled();

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender,
            _pubsub_receiver_task,
        }
    }

    /// Returns the [PersistConfig] being used by this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns persist `Metrics`.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns `ShardMetrics` for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Clears the state cache, allowing for tests with disconnected states.
    ///
    /// Only exposed for testing.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }

    /// Returns a new [PersistClient] for interfacing with persist shards made
    /// durable to the given [PersistLocation].
    ///
    /// The same `location` may be used concurrently from multiple processes.
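    ///
    /// A hedged sketch of the reuse behavior (URIs illustrative): two opens
    /// that share a consensus URI share one underlying consensus connection,
    /// as exercised by the `client_cache` test below.
    ///
    /// ```ignore
    /// let c1 = cache.open(location_a).await?;
    /// let c2 = cache.open(location_b).await?; // same consensus_uri as location_a
    /// // Only one consensus connection exists; c1 and c2 share it.
    /// ```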
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // No sense in measuring rtt latencies more often than this.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);

    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                    Arc::clone(&self.cfg().configs),
                )?;
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // This is intentionally "outside" (wrapping) MetricsBlob so
                // that we don't include cached responses in blob metrics.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}

/// Starts a task to periodically measure the persist-observed latency to
/// blob.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsBlob before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn blob_rtt_latency_task(
    blob: Arc<Tasked<MetricsBlob>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match blob.get(BLOB_GET_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

/// Starts a task to periodically measure the persist-observed latency to
/// consensus.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsConsensus before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn consensus_rtt_latency_task(
    consensus: Arc<Tasked<MetricsConsensus>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics
                        .consensus
                        .rtt_latency
                        .set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

pub(crate) trait DynState: Debug + Send + Sync {
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    fn push_diff(&self, diff: VersionedData);
}

impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, state.seqno
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}

/// A cache of `TypedState`, shared between all machines for that shard.
///
/// This is shared between all machines that come out of the same
/// [PersistClientCache], but in production there is one of those per process,
/// so in practice, we have one copy of state per shard per process.
///
/// The mutex contention between commands is not an issue: if two commands for
/// the same shard execute concurrently, only one can win anyway; the other
/// will retry. With the mutex, we even get to avoid the retry if the racing
/// commands are on the same process.
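///
/// A minimal sketch of the weak-caching idea used by [StateCache::get] (names
/// and comments illustrative, not the real control flow):
///
/// ```ignore
/// // `cell` is the per-shard `Arc<OnceCell<Weak<dyn DynState>>>` entry.
/// match cell.get().and_then(Weak::upgrade) {
///     Some(state) => state, // someone else is keeping it alive: share it
///     None => { /* (re)initialize and store a fresh Weak */ }
/// }
/// ```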
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}

#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}

impl StateCache {
    /// Returns a new StateCache.
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Monoid + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // If the Weak has lost the ability to upgrade,
                            // we've dropped the State and it's gone. Clear the
                            // OnceCell and init a new one.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We actually did the init work, don't bother casting back
                        // the type erased and weak version. Additionally, inform
                        // any listeners of this new state.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Race condition. Between when we first checked the
                        // OnceCell and the `get_or_try_init` call, (1) the
                        // initialization finished, (2) the other user dropped
                        // the strong ref, and (3) the Arc noticed it was down
                        // to only weak refs and dropped the value. Nothing we
                        // can do except try again.
                        continue;
                    };
                    state
                }
            };

            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}

/// A locked decorator for TypedState that abstracts out the specific lock implementation used.
/// Guards the private lock with public accessor fns to make locking scopes more explicit and
/// simpler to reason about.
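///
/// A hedged sketch of the accessor pattern (the read-side metric name here is
/// illustrative; `applier_write` does appear in this file):
///
/// ```ignore
/// let seqno = state.read_lock(&metrics.locks.some_read_lock_metrics, |s| s.seqno);
/// state.write_lock(&metrics.locks.applier_write, |s| { /* mutate state */ });
/// ```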
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    update_semaphore: AwaitableState<Option<tokio::time::Instant>>,
    /// A [SchemaCacheMaps<K, V>], but stored as an Any so the `: Codec` bounds
    /// don't propagate to basically every struct in persist.
    schema_cache: Arc<dyn Any + Send + Sync>,
    _subscription_token: Arc<ShardSubscriptionToken>,
}

impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LockingTypedState {
            shard_id,
            state,
            notifier,
            cfg: _cfg,
            metrics: _metrics,
            shard_metrics: _shard_metrics,
            update_semaphore: _,
            schema_cache: _schema_cache,
            _subscription_token,
        } = self;
        f.debug_struct("LockingTypedState")
            .field("shard_id", shard_id)
            .field("state", state)
            .field("notifier", notifier)
            .finish()
    }
}

impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
    fn new(
        shard_id: ShardId,
        initial_state: TypedState<K, V, T, D>,
        metrics: Arc<Metrics>,
        cfg: Arc<PersistConfig>,
        subscription_token: Arc<ShardSubscriptionToken>,
        diagnostics: &Diagnostics,
    ) -> Self {
        Self {
            shard_id,
            notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
            state: RwLock::new(initial_state),
            cfg: Arc::clone(&cfg),
            shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
            update_semaphore: AwaitableState::new(None),
            schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
            metrics,
            _subscription_token: subscription_token,
        }
    }

    pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
        Arc::clone(&self.schema_cache)
            .downcast::<SchemaCacheMaps<K, V>>()
            .expect("K and V match")
    }
}

pub(crate) const STATE_UPDATE_LEASE_TIMEOUT: Config<Duration> = Config::new(
    "persist_state_update_lease_timeout",
    Duration::from_secs(1),
    "The amount of time for a command to wait for a previous command to finish before executing. \
        (If zero, commands will not wait for others to complete.) Higher values reduce database contention \
        at the cost of higher worst-case latencies for individual requests.",
);

impl<K, V, T, D> LockingTypedState<K, V, T, D> {
    pub(crate) fn shard_id(&self) -> &ShardId {
        &self.shard_id
    }

    pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        mut f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let state = match self.state.try_read() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.read().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
        };
        f(&state)
    }

    pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let mut state = match self.state.try_write() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.write().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state write lock poisoned: {}", err),
        };
        let seqno_before = state.seqno;
        let ret = f(&mut state);
        let seqno_after = state.seqno;
        debug_assert!(seqno_after >= seqno_before);
        if seqno_after > seqno_before {
            self.notifier.notify(seqno_after);
        }
        // For now, make sure to notify while under lock. It's possible to move
        // this out of the lock window, see [StateWatchNotifier::notify].
        drop(state);
        ret
    }

    /// We want to _mostly_ just attempt a single CaS against the same state at once, since
    /// only one concurrent CaS can succeed. However, we also want to guard against a
    /// single hung update blocking all progress globally. We manage this with a shared state,
    /// tracking whether a request is in flight and when it times out. If the timeout is never hit,
    /// this behaves like a semaphore with limit 1... but if our requests _are_ timing out, future
    /// requests will only wait for a bounded time before retrying, and one of those retries will
    /// be able to claim that lease and make progress.
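    ///
    /// A rough usage sketch (hypothetical caller; the CaS call itself is
    /// elided):
    ///
    /// ```ignore
    /// let lease = state.lease_for_update().await;
    /// // ... attempt the compare-and-set against the remote store ...
    /// drop(lease); // clears the lease so that waiters can claim it
    /// ```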
    pub(crate) async fn lease_for_update(&self) -> impl Drop {
        use tokio::time::Instant;

        let timeout = STATE_UPDATE_LEASE_TIMEOUT.get(&self.cfg);

        struct DropLease(Option<(AwaitableState<Option<Instant>>, Instant)>);

        impl Drop for DropLease {
            fn drop(&mut self) {
                if let Some((state, time)) = self.0.take() {
                    // Clear the timeout if it hasn't changed since we set it.
                    state.maybe_modify(|s| {
                        if s.is_some_and(|t| t == time) {
                            *s.get_mut() = None;
                        }
                    })
                }
            }
        }

        // Special case: if the timeout is set to zero, go ahead without taking a lease.
        if timeout.is_zero() {
            return DropLease(None);
        }

        let timeout_state = self.update_semaphore.clone();
        loop {
            let now = tokio::time::Instant::now();
            let expires_at = now + timeout;
            // Claim the lease if there isn't one, or if the current lease has expired.
            let maybe_leased = timeout_state.maybe_modify(|state| {
                if let Some(other_expires_at) = **state
                    && other_expires_at > now
                {
                    // Still locked: sleep until the deadline and try again.
                    Err(other_expires_at)
                } else {
                    *state.get_mut() = Some(expires_at);
                    Ok(())
                }
            });

            match maybe_leased {
                Ok(()) => {
                    break DropLease(Some((timeout_state, expires_at)));
                }
                Err(other_expires_at) => {
                    // Wait until either the lease has dropped or timed out, whichever is first.
                    // If there are a lot of clients trying to update the same state, this may
                    // cause significant lock contention... but the lock is only briefly held,
                    // and anyways that's still cheaper than contending on the remote database.
                    let _ = tokio::time::timeout_at(
                        other_expires_at,
                        timeout_state.wait_while(|s| s.is_some()),
                    )
                    .await;
                }
            }
        }
    }

    pub(crate) fn notifier(&self) -> &StateWatchNotifier {
        &self.notifier
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::pin::pin;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::rpc::NoopPubSubSender;
    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};
    use tokio::sync::oneshot;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location on an empty cache saves the results.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Opening a location with an already opened consensus reuses it, even
        // if the blob is different.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Ditto the other way.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query params and path matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // User info and port also matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // Panicking during init_fn doesn't initialize an entry in the cache.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .into_tokio_handle()
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Returning an error from init_fn doesn't initialize an entry in the cache.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Initialize one shard.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Trying to initialize it again does no work and returns the same state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Trying to initialize with different types doesn't work.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // We can add a shard of a different type.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // The cache holds weak references to State so we reclaim memory if the
        // shard stops being used.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // But we can re-init that shard if necessary.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn update_semaphore() {
        // Check that the update lease mechanism is not susceptible to futurelock.
        // If there is an issue, this test will time out.
        mz_ore::test::init_logging();

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        // Initialize three futures, all of which will grab a lease and then poll a oneshot,
        // which allows us to externally trigger which ones will complete.
        let mk_future = || {
            let (tx, rx) = oneshot::channel();
            let future = async {
                let lease = state.lease_for_update().await;
                let () = rx.await.unwrap();
                drop(lease);
            };
            (future, tx)
        };

        let (one, _one_tx) = mk_future();
        let (two, _two_tx) = mk_future();
        let (three, three_tx) = mk_future();
        let mut one = pin!(one);
        let mut two = pin!(two);
        let mut three = pin!(three);

        // Poll all the futures, but fall through to the default case, since none are ready.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut two => { unreachable!() }
            _ = &mut three => { unreachable!() }
            _ = async {} => {}
        }

        // Allow the third future to complete.
        three_tx.send(()).unwrap();

        // Poll all the futures but the second future. This shouldn't hang, since the third
        // future is now ready to go and the leases held by the others will eventually time out.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut three => {  }
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn update_semaphore_stress() {
        // Check that the update lease mechanism is not susceptible to futurelock.
        // If there is an issue, this test will time out.
        mz_ore::test::init_logging();

        const TIMEOUT: Duration = Duration::from_millis(100);
        const COUNT: u64 = 100;

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        persist_config.set_config(&STATE_UPDATE_LEASE_TIMEOUT, TIMEOUT);
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        let mut futures = (0..(COUNT * 3))
            .map(async |i| {
                // Bind the lease so it is held until this future completes.
                let _lease = state.lease_for_update().await;
                // Either hang forever, succeed quickly, or succeed after hitting the timeout.
                match i % 3 {
                    0 => {
                        let () = std::future::pending().await;
                    }
                    1 => {
                        tokio::time::sleep(Duration::from_millis(i)).await;
                    }
                    _ => {
                        tokio::time::sleep(Duration::from_millis(i) + TIMEOUT).await;
                    }
                }
            })
            .collect::<FuturesUnordered<_>>();

        // All the futures that don't themselves hang forever should resolve.
        for _ in 0..(COUNT * 2) {
            futures.next().await.unwrap();
        }
    }
}