
mz_persist_client/cache.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! A cache of [PersistClient]s indexed by [PersistLocation]s.

use std::any::Any;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, RwLock, TryLockError, Weak};
use std::time::{Duration, Instant};

use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use mz_dyncfg::Config;
use mz_ore::instrument;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::task::{AbortOnDropHandle, JoinHandle};
use mz_ore::url::SensitiveUrl;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
use mz_persist::location::{
    BLOB_GET_LIVENESS_KEY, Blob, CONSENSUS_HEAD_LIVENESS_KEY, Consensus, ExternalError, Tasked,
    VersionedData,
};
use mz_persist_types::{Codec, Codec64};
use timely::progress::Timestamp;
use tokio::sync::{Mutex, OnceCell};
use tracing::debug;

use crate::async_runtime::IsolatedRuntime;
use crate::error::{CodecConcreteType, CodecMismatch};
use crate::internal::cache::BlobMemCache;
use crate::internal::machine::retry_external;
use crate::internal::metrics::{LockMetrics, Metrics, MetricsBlob, MetricsConsensus, ShardMetrics};
use crate::internal::state::TypedState;
use crate::internal::watch::{AwaitableState, StateWatchNotifier};
use crate::rpc::{PubSubClientConnection, PubSubSender, ShardSubscriptionToken};
use crate::schema::SchemaCacheMaps;
use crate::{Diagnostics, PersistClient, PersistConfig, PersistLocation, ShardId};

/// A cache of [PersistClient]s indexed by [PersistLocation]s.
///
/// There should be at most one of these per process. All production
/// PersistClients should be created through this cache.
///
/// This is because, in production, persist is heavily limited by the number of
/// server-side Postgres/Aurora connections. This cache allows PersistClients to
/// share, for example, these Postgres connections.
#[derive(Debug)]
pub struct PersistClientCache {
    /// The tunable knobs for persist.
    pub cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    blob_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Blob>)>>,
    consensus_by_uri: Mutex<BTreeMap<SensitiveUrl, (RttLatencyTask, Arc<dyn Consensus>)>>,
    isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) state_cache: Arc<StateCache>,
    pubsub_sender: Arc<dyn PubSubSender>,
    _pubsub_receiver_task: JoinHandle<()>,
}
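
// Usage sketch (illustrative only, not part of this module's API): a process
// creates one `PersistClientCache` at startup and opens all of its
// `PersistClient`s through it, so clients for the same `PersistLocation` share
// connections. The `mem://` URLs are the in-memory implementations used by the
// tests at the bottom of this file; production deployments pass real blob and
// consensus URLs, and `SensitiveUrl::from_str` needs `std::str::FromStr` in scope.
//
//     let cache = PersistClientCache::new(
//         PersistConfig::new_for_tests(),
//         &MetricsRegistry::new(),
//         |_, _| PubSubClientConnection::noop(),
//     );
//     let client = cache
//         .open(PersistLocation {
//             blob_uri: SensitiveUrl::from_str("mem://blob").expect("valid URL"),
//             consensus_uri: SensitiveUrl::from_str("mem://consensus").expect("valid URL"),
//         })
//         .await
//         .expect("location should open");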

#[derive(Debug)]
struct RttLatencyTask(#[allow(dead_code)] AbortOnDropHandle<()>);

impl PersistClientCache {
    /// Returns a new [PersistClientCache].
    pub fn new<F>(cfg: PersistConfig, registry: &MetricsRegistry, pubsub: F) -> Self
    where
        F: FnOnce(&PersistConfig, Arc<Metrics>) -> PubSubClientConnection,
    {
        let metrics = Arc::new(Metrics::new(&cfg, registry));
        let pubsub_client = pubsub(&cfg, Arc::clone(&metrics));

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_client.sender),
        ));
        let _pubsub_receiver_task = crate::rpc::subscribe_state_cache_to_pubsub(
            Arc::clone(&state_cache),
            pubsub_client.receiver,
        );
        let isolated_runtime =
            IsolatedRuntime::new(registry, Some(cfg.isolated_runtime_worker_threads));

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender: pubsub_client.sender,
            _pubsub_receiver_task,
        }
    }

    /// A test helper that returns a [PersistClientCache] disconnected from
    /// metrics.
    pub fn new_no_metrics() -> Self {
        Self::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        )
    }

    #[cfg(feature = "turmoil")]
    /// Create a [PersistClientCache] for use in turmoil tests.
    ///
    /// Turmoil wants to run all software under test in a single thread, so we disable the
    /// (multi-threaded) isolated runtime.
    pub fn new_for_turmoil() -> Self {
        use crate::rpc::NoopPubSubSender;

        let cfg = PersistConfig::new_for_tests();
        let metrics = Arc::new(Metrics::new(&cfg, &MetricsRegistry::new()));

        let pubsub_sender: Arc<dyn PubSubSender> = Arc::new(NoopPubSubSender);
        let _pubsub_receiver_task = mz_ore::task::spawn(|| "noop", async {});

        let state_cache = Arc::new(StateCache::new(
            &cfg,
            Arc::clone(&metrics),
            Arc::clone(&pubsub_sender),
        ));
        let isolated_runtime = IsolatedRuntime::new_disabled();

        PersistClientCache {
            cfg,
            metrics,
            blob_by_uri: Mutex::new(BTreeMap::new()),
            consensus_by_uri: Mutex::new(BTreeMap::new()),
            isolated_runtime: Arc::new(isolated_runtime),
            state_cache,
            pubsub_sender,
            _pubsub_receiver_task,
        }
    }

    /// Returns the [PersistConfig] being used by this cache.
    pub fn cfg(&self) -> &PersistConfig {
        &self.cfg
    }

    /// Returns persist `Metrics`.
    pub fn metrics(&self) -> &Arc<Metrics> {
        &self.metrics
    }

    /// Returns `ShardMetrics` for the given shard.
    pub fn shard_metrics(&self, shard_id: &ShardId, name: &str) -> Arc<ShardMetrics> {
        self.metrics.shards.shard(shard_id, name)
    }

    /// Clears the state cache, allowing for tests with disconnected states.
    ///
    /// Only exposed for testing.
    pub fn clear_state_cache(&mut self) {
        self.state_cache = Arc::new(StateCache::new(
            &self.cfg,
            Arc::clone(&self.metrics),
            Arc::clone(&self.pubsub_sender),
        ))
    }

    /// Returns a new [PersistClient] for interfacing with persist shards made
    /// durable to the given [PersistLocation].
    ///
    /// The same `location` may be used concurrently from multiple processes.
    #[instrument(level = "debug")]
    pub async fn open(&self, location: PersistLocation) -> Result<PersistClient, ExternalError> {
        let blob = self.open_blob(location.blob_uri).await?;
        let consensus = self.open_consensus(location.consensus_uri).await?;
        PersistClient::new(
            self.cfg.clone(),
            blob,
            consensus,
            Arc::clone(&self.metrics),
            Arc::clone(&self.isolated_runtime),
            Arc::clone(&self.state_cache),
            Arc::clone(&self.pubsub_sender),
        )
    }

    // No sense in measuring rtt latencies more often than this.
    const PROMETHEUS_SCRAPE_INTERVAL: Duration = Duration::from_secs(60);

    async fn open_consensus(
        &self,
        consensus_uri: SensitiveUrl,
    ) -> Result<Arc<dyn Consensus>, ExternalError> {
        let mut consensus_by_uri = self.consensus_by_uri.lock().await;
        let consensus = match consensus_by_uri.entry(consensus_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let consensus = ConsensusConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.postgres_consensus.clone(),
                    Arc::clone(&self.cfg().configs),
                )?;
                let consensus =
                    retry_external(&self.metrics.retries.external.consensus_open, || {
                        consensus.clone().open()
                    })
                    .await;
                let consensus =
                    Arc::new(MetricsConsensus::new(consensus, Arc::clone(&self.metrics)));
                let consensus = Arc::new(Tasked(consensus));
                let task = consensus_rtt_latency_task(
                    Arc::clone(&consensus),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                Arc::clone(
                    &x.insert((RttLatencyTask(task.abort_on_drop()), consensus))
                        .1,
                )
            }
        };
        Ok(consensus)
    }

    async fn open_blob(&self, blob_uri: SensitiveUrl) -> Result<Arc<dyn Blob>, ExternalError> {
        let mut blob_by_uri = self.blob_by_uri.lock().await;
        let blob = match blob_by_uri.entry(blob_uri) {
            Entry::Occupied(x) => Arc::clone(&x.get().1),
            Entry::Vacant(x) => {
                // Intentionally hold the lock, so we don't double connect under
                // concurrency.
                let blob = BlobConfig::try_from(
                    x.key(),
                    Box::new(self.cfg.clone()),
                    self.metrics.s3_blob.clone(),
                    Arc::clone(&self.cfg.configs),
                )
                .await?;
                let blob = retry_external(&self.metrics.retries.external.blob_open, || {
                    blob.clone().open()
                })
                .await;
                let blob = Arc::new(MetricsBlob::new(blob, Arc::clone(&self.metrics)));
                let blob = Arc::new(Tasked(blob));
                let task = blob_rtt_latency_task(
                    Arc::clone(&blob),
                    Arc::clone(&self.metrics),
                    Self::PROMETHEUS_SCRAPE_INTERVAL,
                )
                .await;
                // This is intentionally "outside" (wrapping) MetricsBlob so
                // that we don't include cached responses in blob metrics.
                let blob = BlobMemCache::new(&self.cfg, Arc::clone(&self.metrics), blob);
                Arc::clone(&x.insert((RttLatencyTask(task.abort_on_drop()), blob)).1)
            }
        };
        Ok(blob)
    }
}

/// Starts a task to periodically measure the persist-observed latency to
/// blob.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsBlob before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn blob_rtt_latency_task(
    blob: Arc<Tasked<MetricsBlob>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::blob_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match blob.get(BLOB_GET_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics.blob.rtt_latency.set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

/// Starts a task to periodically measure the persist-observed latency to
/// consensus.
///
/// This is a task, rather than something like looking at the latencies of prod
/// traffic, so that we minimize any issues around Futures not being polled
/// promptly (as can and does happen with the Timely-polled Futures).
///
/// The caller is responsible for shutdown via aborting the `JoinHandle`.
///
/// No matter whether we wrap MetricsConsensus before or after we start up the
/// rtt latency task, there's the possibility for it being confusing at some
/// point. Err on the side of more data (including the latency measurements) to
/// start.
#[allow(clippy::unused_async)]
async fn consensus_rtt_latency_task(
    consensus: Arc<Tasked<MetricsConsensus>>,
    metrics: Arc<Metrics>,
    measurement_interval: Duration,
) -> JoinHandle<()> {
    mz_ore::task::spawn(|| "persist::consensus_rtt_latency", async move {
        // Use the tokio Instant for next_measurement because the reclock tests
        // mess with the tokio sleep clock.
        let mut next_measurement = tokio::time::Instant::now();
        loop {
            tokio::time::sleep_until(next_measurement).await;
            let start = Instant::now();
            match consensus.head(CONSENSUS_HEAD_LIVENESS_KEY).await {
                Ok(_) => {
                    metrics
                        .consensus
                        .rtt_latency
                        .set(start.elapsed().as_secs_f64());
                }
                Err(_) => {
                    // Don't spam retries if this returns an error. We're
                    // guaranteed by the method signature that we've already got
                    // metrics coverage of these, so we'll count the errors.
                }
            }
            next_measurement = tokio::time::Instant::now() + measurement_interval;
        }
    })
}

pub(crate) trait DynState: Debug + Send + Sync {
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>);
    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
    fn push_diff(&self, diff: VersionedData);
}

impl<K, V, T, D> DynState for LockingTypedState<K, V, T, D>
where
    K: Codec,
    V: Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Codec64,
{
    fn codecs(&self) -> (String, String, String, String, Option<CodecConcreteType>) {
        (
            K::codec_name(),
            V::codec_name(),
            T::codec_name(),
            D::codec_name(),
            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
        )
    }

    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }

    fn push_diff(&self, diff: VersionedData) {
        self.write_lock(&self.metrics.locks.applier_write, |state| {
            let seqno_before = state.seqno;
            state.apply_encoded_diffs(&self.cfg, &self.metrics, std::iter::once(&diff));
            let seqno_after = state.seqno;
            assert!(seqno_after >= seqno_before);

            if seqno_before != seqno_after {
                debug!(
                    "applied pushed diff {}. seqno {} -> {}.",
                    state.shard_id, seqno_before, state.seqno
                );
                self.shard_metrics.pubsub_push_diff_applied.inc();
            } else {
                debug!(
                    "failed to apply pushed diff {}. seqno {} vs diff {}",
                    state.shard_id, seqno_before, diff.seqno
                );
                if diff.seqno <= seqno_before {
                    self.shard_metrics.pubsub_push_diff_not_applied_stale.inc();
                } else {
                    self.shard_metrics
                        .pubsub_push_diff_not_applied_out_of_order
                        .inc();
                }
            }
        })
    }
}

/// A cache of `TypedState`, shared between all machines for that shard.
///
/// This is shared between all machines that come out of the same
/// [PersistClientCache], but in production there is one of those per process,
/// so in practice, we have one copy of state per shard per process.
///
/// The mutex contention between commands is not an issue, because if two
/// commands for the same shard are executing concurrently, only one can win
/// anyway; the other will retry. With the mutex, we even get to avoid the retry
/// if the racing commands are on the same process.
#[derive(Debug)]
pub struct StateCache {
    cfg: Arc<PersistConfig>,
    pub(crate) metrics: Arc<Metrics>,
    states: Arc<std::sync::Mutex<BTreeMap<ShardId, Arc<OnceCell<Weak<dyn DynState>>>>>>,
    pubsub_sender: Arc<dyn PubSubSender>,
}
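
// Sharing sketch (illustrative only; `StateCache::get` is crate-private): two
// callers asking for the same shard get handles to the same underlying
// `LockingTypedState`, so a process holds at most one copy of state per shard.
// Here `init_fn` is a placeholder for an async closure returning a fresh
// `TypedState` (see the `state_cache` test at the bottom of this file); it only
// runs for whichever caller actually performs the initialization.
//
//     let cache = StateCache::new_no_metrics();
//     let shard_id = ShardId::new();
//     let a = cache
//         .get::<(), (), u64, i64, _, _>(shard_id, init_fn, &Diagnostics::for_tests())
//         .await?;
//     let b = cache
//         .get::<(), (), u64, i64, _, _>(shard_id, init_fn, &Diagnostics::for_tests())
//         .await?;
//     // `a` and `b` point at the same state.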

#[derive(Debug)]
enum StateCacheInit {
    Init(Arc<dyn DynState>),
    NeedInit(Arc<OnceCell<Weak<dyn DynState>>>),
}

impl StateCache {
    /// Returns a new StateCache.
    pub fn new(
        cfg: &PersistConfig,
        metrics: Arc<Metrics>,
        pubsub_sender: Arc<dyn PubSubSender>,
    ) -> Self {
        StateCache {
            cfg: Arc::new(cfg.clone()),
            metrics,
            states: Default::default(),
            pubsub_sender,
        }
    }

    #[cfg(test)]
    pub(crate) fn new_no_metrics() -> Self {
        Self::new(
            &PersistConfig::new_for_tests(),
            Arc::new(Metrics::new(
                &PersistConfig::new_for_tests(),
                &MetricsRegistry::new(),
            )),
            Arc::new(crate::rpc::NoopPubSubSender),
        )
    }

    pub(crate) async fn get<K, V, T, D, F, InitFn>(
        &self,
        shard_id: ShardId,
        mut init_fn: InitFn,
        diagnostics: &Diagnostics,
    ) -> Result<Arc<LockingTypedState<K, V, T, D>>, Box<CodecMismatch>>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + Codec64 + Sync,
        D: Monoid + Codec64,
        F: Future<Output = Result<TypedState<K, V, T, D>, Box<CodecMismatch>>>,
        InitFn: FnMut() -> F,
    {
        loop {
            let init = {
                let mut states = self.states.lock().expect("lock poisoned");
                let state = states.entry(shard_id).or_default();
                match state.get() {
                    Some(once_val) => match once_val.upgrade() {
                        Some(x) => StateCacheInit::Init(x),
                        None => {
                            // If the Weak has lost the ability to upgrade,
                            // we've dropped the State and it's gone. Clear the
                            // OnceCell and init a new one.
                            *state = Arc::new(OnceCell::new());
                            StateCacheInit::NeedInit(Arc::clone(state))
                        }
                    },
                    None => StateCacheInit::NeedInit(Arc::clone(state)),
                }
            };

            let state = match init {
                StateCacheInit::Init(x) => x,
                StateCacheInit::NeedInit(init_once) => {
                    let mut did_init: Option<Arc<LockingTypedState<K, V, T, D>>> = None;
                    let state = init_once
                        .get_or_try_init::<Box<CodecMismatch>, _, _>(|| async {
                            let init_res = init_fn().await;
                            let state = Arc::new(LockingTypedState::new(
                                shard_id,
                                init_res?,
                                Arc::clone(&self.metrics),
                                Arc::clone(&self.cfg),
                                Arc::clone(&self.pubsub_sender).subscribe(&shard_id),
                                diagnostics,
                            ));
                            let ret = Arc::downgrade(&state);
                            did_init = Some(state);
                            let ret: Weak<dyn DynState> = ret;
                            Ok(ret)
                        })
                        .await?;
                    if let Some(x) = did_init {
                        // We actually did the init work, don't bother casting back
                        // the type erased and weak version. Additionally, inform
                        // any listeners of this new state.
                        return Ok(x);
                    }
                    let Some(state) = state.upgrade() else {
                        // Race condition. Between when we first checked the
                        // OnceCell and the `get_or_try_init` call, (1) the
                        // initialization finished, (2) the other user dropped
                        // the strong ref, and (3) the Arc noticed it was down
                        // to only weak refs and dropped the value. Nothing we
                        // can do except try again.
                        continue;
                    };
                    state
                }
            };

            match Arc::clone(&state)
                .as_any()
                .downcast::<LockingTypedState<K, V, T, D>>()
            {
                Ok(x) => return Ok(x),
                Err(_) => {
                    return Err(Box::new(CodecMismatch {
                        requested: (
                            K::codec_name(),
                            V::codec_name(),
                            T::codec_name(),
                            D::codec_name(),
                            Some(CodecConcreteType(std::any::type_name::<(K, V, T, D)>())),
                        ),
                        actual: state.codecs(),
                    }));
                }
            }
        }
    }

    pub(crate) fn get_state_weak(&self, shard_id: &ShardId) -> Option<Weak<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .map(Weak::clone)
    }

    #[cfg(test)]
    fn get_cached(&self, shard_id: &ShardId) -> Option<Arc<dyn DynState>> {
        self.states
            .lock()
            .expect("lock")
            .get(shard_id)
            .and_then(|x| x.get())
            .and_then(|x| x.upgrade())
    }

    #[cfg(test)]
    fn initialized_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.initialized())
            .count()
    }

    #[cfg(test)]
    fn strong_count(&self) -> usize {
        self.states
            .lock()
            .expect("lock")
            .values()
            .filter(|x| x.get().map_or(false, |x| x.upgrade().is_some()))
            .count()
    }
}

/// A locked decorator for TypedState that abstracts out the specific lock implementation used.
/// Guards the private lock with public accessor fns to make locking scopes more explicit and
/// simpler to reason about.
pub(crate) struct LockingTypedState<K, V, T, D> {
    shard_id: ShardId,
    state: RwLock<TypedState<K, V, T, D>>,
    notifier: StateWatchNotifier,
    cfg: Arc<PersistConfig>,
    metrics: Arc<Metrics>,
    shard_metrics: Arc<ShardMetrics>,
    update_semaphore: AwaitableState<Option<tokio::time::Instant>>,
    /// A [SchemaCacheMaps<K, V>], but stored as an Any so the `: Codec` bounds
    /// don't propagate to basically every struct in persist.
    schema_cache: Arc<dyn Any + Send + Sync>,
    _subscription_token: Arc<ShardSubscriptionToken>,
}

impl<K, V, T: Debug, D> Debug for LockingTypedState<K, V, T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let LockingTypedState {
            shard_id,
            state,
            notifier,
            cfg: _cfg,
            metrics: _metrics,
            shard_metrics: _shard_metrics,
            update_semaphore: _,
            schema_cache: _schema_cache,
            _subscription_token,
        } = self;
        f.debug_struct("LockingTypedState")
            .field("shard_id", shard_id)
            .field("state", state)
            .field("notifier", notifier)
            .finish()
    }
}

impl<K: Codec, V: Codec, T, D> LockingTypedState<K, V, T, D> {
    fn new(
        shard_id: ShardId,
        initial_state: TypedState<K, V, T, D>,
        metrics: Arc<Metrics>,
        cfg: Arc<PersistConfig>,
        subscription_token: Arc<ShardSubscriptionToken>,
        diagnostics: &Diagnostics,
    ) -> Self {
        Self {
            shard_id,
            notifier: StateWatchNotifier::new(Arc::clone(&metrics)),
            state: RwLock::new(initial_state),
            cfg: Arc::clone(&cfg),
            shard_metrics: metrics.shards.shard(&shard_id, &diagnostics.shard_name),
            update_semaphore: AwaitableState::new(None),
            schema_cache: Arc::new(SchemaCacheMaps::<K, V>::new(&metrics.schema)),
            metrics,
            _subscription_token: subscription_token,
        }
    }

    pub(crate) fn schema_cache(&self) -> Arc<SchemaCacheMaps<K, V>> {
        Arc::clone(&self.schema_cache)
            .downcast::<SchemaCacheMaps<K, V>>()
            .expect("K and V match")
    }
}

pub(crate) const STATE_UPDATE_LEASE_TIMEOUT: Config<Duration> = Config::new(
    "persist_state_update_lease_timeout",
    Duration::from_secs(1),
    "The amount of time for a command to wait for a previous command to finish before executing. \
        (If zero, commands will not wait for others to complete.) Higher values reduce database contention \
        at the cost of higher worst-case latencies for individual requests.",
);
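
// Tuning sketch (illustrative): like other persist dyncfgs, the lease timeout can
// be overridden on a `PersistConfig`, e.g. in tests (see `update_semaphore_stress`
// below) or via the usual dyncfg plumbing in production.
//
//     let persist_config = PersistConfig::new_for_tests();
//     persist_config.set_config(&STATE_UPDATE_LEASE_TIMEOUT, Duration::from_millis(100));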

impl<K, V, T, D> LockingTypedState<K, V, T, D> {
    pub(crate) fn shard_id(&self) -> &ShardId {
        &self.shard_id
    }

    pub(crate) fn read_lock<R, F: FnMut(&TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        mut f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let state = match self.state.try_read() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.read().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state read lock poisoned: {}", err),
        };
        f(&state)
    }

    pub(crate) fn write_lock<R, F: FnOnce(&mut TypedState<K, V, T, D>) -> R>(
        &self,
        metrics: &LockMetrics,
        f: F,
    ) -> R {
        metrics.acquire_count.inc();
        let mut state = match self.state.try_write() {
            Ok(x) => x,
            Err(TryLockError::WouldBlock) => {
                metrics.blocking_acquire_count.inc();
                let start = Instant::now();
                let state = self.state.write().expect("lock poisoned");
                metrics
                    .blocking_seconds
                    .inc_by(start.elapsed().as_secs_f64());
                state
            }
            Err(TryLockError::Poisoned(err)) => panic!("state write lock poisoned: {}", err),
        };
        let seqno_before = state.seqno;
        let ret = f(&mut state);
        let seqno_after = state.seqno;
        debug_assert!(seqno_after >= seqno_before);
        if seqno_after > seqno_before {
            self.notifier.notify(seqno_after);
        }
        // For now, make sure to notify while under lock. It's possible to move
        // this out of the lock window, see [StateWatchNotifier::notify].
        drop(state);
        ret
    }
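
    // Usage sketch (illustrative): callers pass the `LockMetrics` describing who
    // is taking the lock plus a closure that runs while the lock is held, keeping
    // the locked scope explicit and instrumented. `push_diff` above is a real
    // example of the write side; a read might look like the following, where
    // `lock_metrics` stands for whichever `LockMetrics` from `metrics.locks`
    // describes the caller.
    //
    //     let seqno = state.read_lock(lock_metrics, |s| s.seqno);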

    /// We want to _mostly_ just attempt a single CaS against the same state at once, since
    /// only one concurrent CaS can succeed. However, we also want to guard against a
    /// single hung update blocking all progress globally. We manage this with a shared state,
    /// tracking whether a request is in flight and when it times out. If the timeout is never hit,
    /// this behaves like a semaphore with limit 1... but if our requests _are_ timing out, future
    /// requests will only wait for a bounded time before retrying, and one of those retries will
    /// be able to claim that lease and make progress.
    pub(crate) async fn lease_for_update(&self) -> impl Drop {
        use tokio::time::Instant;

        let timeout = STATE_UPDATE_LEASE_TIMEOUT.get(&self.cfg);

        struct DropLease(Option<(AwaitableState<Option<Instant>>, Instant)>);

        impl Drop for DropLease {
            fn drop(&mut self) {
                if let Some((state, time)) = self.0.take() {
                    // Clear the timeout if it hasn't changed since we set it.
                    state.maybe_modify(|s| {
                        if s.is_some_and(|t| t == time) {
                            *s.get_mut() = None;
                        }
                    })
                }
            }
        }

        // Special case: if the timeout is set to zero, go ahead without taking a lease.
        if timeout.is_zero() {
            return DropLease(None);
        }

        let timeout_state = self.update_semaphore.clone();
        loop {
            let now = tokio::time::Instant::now();
            let expires_at = now + timeout;
            // Claim the lease if there isn't one, or if the current lease has expired.
            let maybe_leased = timeout_state.maybe_modify(|state| {
                if let Some(other_expires_at) = **state
                    && other_expires_at > now
                {
                    // Still locked: sleep until the deadline and try again.
                    Err(other_expires_at)
                } else {
                    *state.get_mut() = Some(expires_at);
                    Ok(())
                }
            });

            match maybe_leased {
                Ok(()) => {
                    break DropLease(Some((timeout_state, expires_at)));
                }
                Err(other_expires_at) => {
                    // Wait until either the lease has dropped or timed out, whichever is first.
                    // If there are a lot of clients trying to update the same state, this may
                    // cause significant lock contention... but the lock is only briefly held,
                    // and anyways that's still cheaper than contending on the remote database.
                    let _ = tokio::time::timeout_at(
                        other_expires_at,
                        timeout_state.wait_while(|s| s.is_some()),
                    )
                    .await;
                }
            }
        }
    }

    pub(crate) fn notifier(&self) -> &StateWatchNotifier {
        &self.notifier
    }
}
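
// Usage sketch (illustrative): a command holds the lease guard across its CaS
// attempt and drops it afterwards, which wakes anyone waiting in
// `lease_for_update`. If the holder hangs, waiters time out after
// `STATE_UPDATE_LEASE_TIMEOUT` and claim the lease themselves (see the
// `update_semaphore` tests below).
//
//     let lease = state.lease_for_update().await;
//     // ... attempt the compare-and-set against consensus here ...
//     drop(lease);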

#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::pin::pin;
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::rpc::NoopPubSubSender;
    use futures::stream::{FuturesUnordered, StreamExt};
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_ore::task::spawn;
    use mz_ore::{assert_err, assert_none};
    use tokio::sync::oneshot;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn client_cache() {
        let cache = PersistClientCache::new(
            PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
            |_, _| PubSubClientConnection::noop(),
        );
        assert_eq!(cache.blob_by_uri.lock().await.len(), 0);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 0);

        // Opening a location on an empty cache saves the results.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_zero").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 1);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Opening a location with an already opened consensus reuses it, even
        // if the blob is different.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_zero").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 1);

        // Ditto the other way.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one").expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 2);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 2);

        // Query params and path matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://blob_one?foo").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://consensus_one/bar")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 3);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 3);

        // User info and port also matter, so we get new instances.
        let _ = cache
            .open(PersistLocation {
                blob_uri: SensitiveUrl::from_str("mem://user@blob_one").expect("invalid URL"),
                consensus_uri: SensitiveUrl::from_str("mem://@consensus_one:123")
                    .expect("invalid URL"),
            })
            .await
            .expect("failed to open location");
        assert_eq!(cache.blob_by_uri.lock().await.len(), 4);
        assert_eq!(cache.consensus_by_uri.lock().await.len(), 4);
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_cache() {
        mz_ore::test::init_logging();
        fn new_state<K, V, T, D>(shard_id: ShardId) -> TypedState<K, V, T, D>
        where
            K: Codec,
            V: Codec,
            T: Timestamp + Lattice + Codec64,
            D: Codec64,
        {
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            )
        }
        fn assert_same<K, V, T, D>(
            state1: &LockingTypedState<K, V, T, D>,
            state2: &LockingTypedState<K, V, T, D>,
        ) {
            let pointer1 = format!("{:p}", state1.state.read().expect("lock").deref());
            let pointer2 = format!("{:p}", state2.state.read().expect("lock").deref());
            assert_eq!(pointer1, pointer2);
        }

        let s1 = ShardId::new();
        let states = Arc::new(StateCache::new_no_metrics());

        // The cache starts empty.
        assert_eq!(states.states.lock().expect("lock").len(), 0);

        // Panicking during init_fn doesn't initialize an entry in the cache.
        let s = Arc::clone(&states);
        let res = spawn(|| "test", async move {
            s.get::<(), (), u64, i64, _, _>(
                s1,
                || async { panic!("forced panic") },
                &Diagnostics::for_tests(),
            )
            .await
        })
        .into_tokio_handle()
        .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Returning an error from init_fn doesn't initialize an entry in the cache.
        let res = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async {
                    Err(Box::new(CodecMismatch {
                        requested: ("".into(), "".into(), "".into(), "".into(), None),
                        actual: ("".into(), "".into(), "".into(), "".into(), None),
                    }))
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_err!(res);
        assert_eq!(states.initialized_count(), 0);

        // Initialize one shard.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), true);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // Trying to initialize it again does no work and returns the same state.
        let did_work = Arc::new(AtomicBool::new(false));
        let s1_state2 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);
        assert_same(&s1_state1, &s1_state2);

        // Trying to initialize with different types doesn't work.
        let did_work = Arc::new(AtomicBool::new(false));
        let res = states
            .get::<String, (), u64, i64, _, _>(
                s1,
                || {
                    let did_work = Arc::clone(&did_work);
                    async move {
                        did_work.store(true, Ordering::SeqCst);
                        Ok(new_state(s1))
                    }
                },
                &Diagnostics::for_tests(),
            )
            .await;
        assert_eq!(did_work.load(Ordering::SeqCst), false);
        assert_eq!(
            format!("{}", res.expect_err("types shouldn't match")),
            "requested codecs (\"String\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"(alloc::string::String, (), u64, i64)\"))) did not match ones in durable storage (\"()\", \"()\", \"u64\", \"i64\", Some(CodecConcreteType(\"((), (), u64, i64)\")))"
        );
        assert_eq!(states.initialized_count(), 1);
        assert_eq!(states.strong_count(), 1);

        // We can add a shard of a different type.
        let s2 = ShardId::new();
        let s2_state1 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        let s2_state2 = states
            .get::<String, (), u64, i64, _, _>(
                s2,
                || async { Ok(new_state(s2)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_same(&s2_state1, &s2_state2);

        // The cache holds weak references to State so we reclaim memory if the
        // shard stops being used.
        drop(s1_state1);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state2);
        assert_eq!(states.strong_count(), 1);
        assert_eq!(states.initialized_count(), 2);
        assert_none!(states.get_cached(&s1));

        // But we can re-init that shard if necessary.
        let s1_state1 = states
            .get::<(), (), u64, i64, _, _>(
                s1,
                || async { Ok(new_state(s1)) },
                &Diagnostics::for_tests(),
            )
            .await
            .expect("should successfully initialize");
        assert_eq!(states.initialized_count(), 2);
        assert_eq!(states.strong_count(), 2);
        drop(s1_state1);
        assert_eq!(states.strong_count(), 1);
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn state_cache_concurrency() {
        mz_ore::test::init_logging();

        const COUNT: usize = 1000;
        let id = ShardId::new();
        let cache = StateCache::new_no_metrics();
        let diagnostics = Diagnostics::for_tests();

        let mut futures = (0..COUNT)
            .map(|_| {
                cache.get::<(), (), u64, i64, _, _>(
                    id,
                    || async {
                        Ok(TypedState::new(
                            DUMMY_BUILD_INFO.semver_version(),
                            id,
                            "host".into(),
                            0,
                        ))
                    },
                    &diagnostics,
                )
            })
            .collect::<FuturesUnordered<_>>();

        for _ in 0..COUNT {
            let _ = futures.next().await.unwrap();
        }
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn update_semaphore() {
        // Check that the update lease mechanism is not susceptible to futurelock.
        // If there is an issue, this test will time out.
        mz_ore::test::init_logging();

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        // Initialize three futures, all of which will grab a lease and then poll a oneshot,
        // which allows us to externally trigger which ones will complete.
        let mk_future = || {
            let (tx, rx) = oneshot::channel();
            let future = async {
                let lease = state.lease_for_update().await;
                let () = rx.await.unwrap();
                drop(lease);
            };
            (future, tx)
        };

        let (one, _one_tx) = mk_future();
        let (two, _two_tx) = mk_future();
        let (three, three_tx) = mk_future();
        let mut one = pin!(one);
        let mut two = pin!(two);
        let mut three = pin!(three);

        // Poll all the futures, but fall through to the default case, since none are ready.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut two => { unreachable!() }
            _ = &mut three => { unreachable!() }
            _ = async {} => {}
        }

        // Allow the third future to complete.
        three_tx.send(()).unwrap();

        // Poll all the futures but the second future. This shouldn't hang, since the third future
        // is now ready to go and the others should eventually time out.
        tokio::select! { biased;
            _ = &mut one => { unreachable!() }
            _ = &mut three => {  }
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn update_semaphore_stress() {
        // Check that the update lease mechanism is not susceptible to futurelock.
        // If there is an issue, this test will time out.
        mz_ore::test::init_logging();

        const TIMEOUT: Duration = Duration::from_millis(100);
        const COUNT: u64 = 100;

        let shard_id = ShardId::new();
        let persist_config = Arc::new(PersistConfig::new_for_tests());
        persist_config.set_config(&STATE_UPDATE_LEASE_TIMEOUT, TIMEOUT);
        let pubsub = Arc::new(NoopPubSubSender);
        let state: LockingTypedState<String, (), u64, i64> = LockingTypedState::new(
            shard_id,
            TypedState::new(
                DUMMY_BUILD_INFO.semver_version(),
                shard_id,
                "host".into(),
                0,
            ),
            Arc::new(Metrics::new(&*persist_config, &MetricsRegistry::new())),
            persist_config,
            pubsub.subscribe(&shard_id),
            &Diagnostics::for_tests(),
        );

        let mut futures = (0..(COUNT * 3))
            .map(async |i| {
                state.lease_for_update().await;
                // Either hang forever, succeed quickly, or succeed after hitting the timeout.
                match i % 3 {
                    0 => {
                        let () = std::future::pending().await;
                    }
                    1 => {
                        tokio::time::sleep(Duration::from_millis(i)).await;
                    }
                    _ => {
                        tokio::time::sleep(Duration::from_millis(i) + TIMEOUT).await;
                    }
                }
            })
            .collect::<FuturesUnordered<_>>();

        // All the futures that don't themselves hang forever should resolve.
        for _ in 0..(COUNT * 2) {
            let _ = futures.next().await.unwrap();
        }
    }
}