1use std::collections::BTreeMap;
13use std::fmt::{Debug, Formatter};
14use std::hash::Hash;
15use std::sync::Arc;
16use std::time::Duration;
17
18use differential_dataflow::consolidation::{consolidate, consolidate_updates};
19use mz_ore::collections::{AssociativeExt, HashSet};
20use mz_ore::soft_panic_or_log;
21use mz_persist_client::critical::SinceHandle;
22use mz_persist_client::error::UpperMismatch;
23use mz_persist_client::read::{ListenEvent, Subscribe};
24use mz_persist_client::write::WriteHandle;
25use mz_persist_client::{Diagnostics, PersistClient};
26use mz_persist_types::{Codec, ShardId};
27use timely::progress::Antichain;
28use tracing::debug;
29
/// Describes how a [`DurableCache`] converts between the in-memory key/value types handed to
/// callers and the encoded types that are written to and read from the backing persist shard.
pub trait DurableCacheCodec: Debug + Eq {
    /// In-memory key type; used as the key of the cache's local map.
    type Key: Ord + Hash + Clone + Debug;
    /// In-memory value type returned to callers.
    type Val: Eq + Debug;
    /// Encoded form of a key, as stored in the shard.
    type KeyCodec: Codec + Ord + Debug + Clone;
    /// Encoded form of a value, as stored in the shard.
    type ValCodec: Codec + Ord + Debug + Clone;

    /// Returns the key and value schemas handed to persist when the shard is opened.
    fn schemas() -> (
        <Self::KeyCodec as Codec>::Schema,
        <Self::ValCodec as Codec>::Schema,
    );
    /// Converts an in-memory key/value pair into its encoded form.
    fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec);
    /// Converts an encoded key/value pair back into its in-memory form.
    fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val);
}
43
44#[derive(Debug)]
45pub enum Error {
46 WriteConflict(UpperMismatch<u64>),
47}
48
49impl std::fmt::Display for Error {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Error::WriteConflict(err) => write!(f, "{err}"),
53 }
54 }
55}
56
/// A single entry of the local map: the decoded value together with the encoded key and value
/// it was decoded from.
#[derive(Debug, PartialEq, Eq)]
struct LocalVal<C: DurableCacheCodec> {
    /// Encoded key as read from the shard; cloned when a retraction for this entry is written.
    encoded_key: C::KeyCodec,
    /// Decoded value handed out to callers.
    decoded_val: C::Val,
    /// Encoded value as read from the shard; cloned when a retraction for this entry is written.
    encoded_val: C::ValCodec,
}
63
/// A key-value cache whose entries are written to a persist shard and mirrored in a local,
/// in-memory map.
#[derive(Debug)]
pub struct DurableCache<C: DurableCacheCodec> {
    /// Critical since handle for the shard; downgraded after every successful write.
    since_handle: SinceHandle<C::KeyCodec, C::ValCodec, u64, i64, i64>,
    /// Write handle used to append updates and to look up the shard's upper.
    write: WriteHandle<C::KeyCodec, C::ValCodec, u64, i64>,
    /// Subscription from which updates and progress events are fetched while syncing.
    subscribe: Subscribe<C::KeyCodec, C::ValCodec, u64, i64>,

    /// Entries applied from the subscription so far, keyed by decoded key.
    local: BTreeMap<C::Key, LocalVal<C>>,
    /// Most recent progress timestamp received from `subscribe`; also used as the expected
    /// upper when writing.
    local_progress: u64,
}
73
/// Passed as the final argument to `PersistClient::open` in [`DurableCache::new`].
// NOTE(review): presumably tells persist that this reader relies on the critical since handle
// opened alongside it — confirm against the persist client documentation.
const USE_CRITICAL_SINCE: bool = true;
75
76impl<C: DurableCacheCodec> DurableCache<C> {
77 pub async fn new(persist: &PersistClient, shard_id: ShardId, purpose: &str) -> Self {
79 let diagnostics = Diagnostics {
80 shard_name: format!("{purpose}_cache"),
81 handle_purpose: format!("durable persist cache: {purpose}"),
82 };
83 let since_handle = persist
84 .open_critical_since(
85 shard_id,
86 PersistClient::CONTROLLER_CRITICAL_SINCE,
89 diagnostics.clone(),
90 )
91 .await
92 .expect("invalid usage");
93 let (key_schema, val_schema) = C::schemas();
94 let (mut write, read) = persist
95 .open(
96 shard_id,
97 Arc::new(key_schema),
98 Arc::new(val_schema),
99 diagnostics,
100 USE_CRITICAL_SINCE,
101 )
102 .await
103 .expect("shard codecs should not change");
104 let res = write
106 .compare_and_append_batch(
107 &mut [],
108 Antichain::from_elem(0),
109 Antichain::from_elem(1),
110 true,
111 )
112 .await
113 .expect("usage was valid");
114 match res {
115 Ok(()) => {}
117 Err(UpperMismatch { .. }) => {}
119 }
120
121 let as_of = read.since().clone();
122 let subscribe = read
123 .subscribe(as_of)
124 .await
125 .expect("capability should be held at this since");
126 let mut ret = DurableCache {
127 since_handle,
128 write,
129 subscribe,
130 local: BTreeMap::new(),
131 local_progress: 0,
132 };
133 ret.sync_to(ret.write.upper().as_option().copied()).await;
134 ret
135 }
136
137 async fn sync_to(&mut self, progress: Option<u64>) -> u64 {
138 let progress = progress.expect("cache shard should not be closed");
139 let mut updates: BTreeMap<_, Vec<_>> = BTreeMap::new();
140
141 while self.local_progress < progress {
142 let events = self.subscribe.fetch_next().await;
143 for event in events {
144 match event {
145 ListenEvent::Updates(batch_updates) => {
146 debug!("syncing updates {batch_updates:?}");
147 for update in batch_updates {
148 updates.entry(update.1).or_default().push(update);
149 }
150 }
151 ListenEvent::Progress(x) => {
152 debug!("synced up to {x:?}");
153 self.local_progress =
154 x.into_option().expect("cache shard should not be closed");
155 while let Some((ts, mut updates)) = updates.pop_first() {
158 assert!(
159 ts < self.local_progress,
160 "expected {} < {}",
161 ts,
162 self.local_progress
163 );
164 assert!(
165 updates.iter().all(|(_, update_ts, _)| ts == *update_ts),
166 "all updates should be for time {ts}, updates: {updates:?}"
167 );
168
169 consolidate_updates(&mut updates);
170 updates.sort_by(|(_, _, d1), (_, _, d2)| d1.cmp(d2));
171 for ((k, v), t, d) in updates {
172 let encoded_key = k.unwrap();
173 let encoded_val = v.unwrap();
174 let (decoded_key, decoded_val) =
175 C::decode(&encoded_key, &encoded_val);
176 let val = LocalVal {
177 encoded_key,
178 decoded_val,
179 encoded_val,
180 };
181
182 if d == 1 {
183 self.local.expect_insert(
184 decoded_key,
185 val,
186 "duplicate cache entry",
187 );
188 } else if d == -1 {
189 let prev = self
190 .local
191 .expect_remove(&decoded_key, "entry does not exist");
192 assert_eq!(
193 val, prev,
194 "removed val does not match expected val"
195 );
196 } else {
197 panic!(
198 "unexpected diff: (({:?}, {:?}), {}, {})",
199 decoded_key, val.decoded_val, t, d
200 );
201 }
202 }
203 }
204 }
205 }
206 }
207 }
208 assert_eq!(updates, BTreeMap::new(), "all updates should be applied");
209 progress
210 }
211
212 pub fn get_local(&self, key: &C::Key) -> Option<&C::Val> {
215 self.local.get(key).map(|val| &val.decoded_val)
216 }
217
218 pub async fn get(&mut self, key: &C::Key, val_fn: impl FnOnce() -> C::Val) -> &C::Val {
222 if self.local.contains_key(key) {
226 return self.get_local(key).expect("checked above");
227 }
228
229 self.sync_to(self.write.shared_upper().into_option()).await;
232 if self.local.contains_key(key) {
233 return self.get_local(key).expect("checked above");
234 }
235
236 let val = val_fn();
238 let mut expected_upper = self.local_progress;
239 let update = (C::encode(key, &val), 1);
240 loop {
241 let ret = self
242 .compare_and_append([update.clone()], expected_upper)
243 .await;
244 match ret {
245 Ok(new_upper) => {
246 self.sync_to(Some(new_upper)).await;
247 return self.get_local(key).expect("just inserted");
248 }
249 Err(err) => {
250 expected_upper = self.sync_to(err.current.into_option()).await;
251 if self.local.contains_key(key) {
252 return self.get_local(key).expect("checked above");
253 }
254 continue;
255 }
256 }
257 }
258 }
259
260 pub fn entries_local(&self) -> impl Iterator<Item = (&C::Key, &C::Val)> {
262 self.local.iter().map(|(key, val)| (key, &val.decoded_val))
263 }
264
265 pub async fn set(&mut self, key: &C::Key, value: Option<&C::Val>) {
269 while let Err(err) = self.try_set(key, value).await {
270 debug!("failed to set entry: {err} ... retrying");
271 }
272 }
273
274 pub async fn set_many(&mut self, entries: &[(&C::Key, Option<&C::Val>)]) {
279 while let Err(err) = self.try_set_many(entries).await {
280 debug!("failed to set entries: {err} ... retrying");
281 }
282 }
283
284 pub async fn try_set(&mut self, key: &C::Key, value: Option<&C::Val>) -> Result<(), Error> {
289 self.try_set_many(&[(key, value)]).await
290 }
291
292 pub async fn try_set_many(
298 &mut self,
299 entries: &[(&C::Key, Option<&C::Val>)],
300 ) -> Result<(), Error> {
301 let expected_upper = self.local_progress;
302 let mut updates = Vec::new();
303 let mut seen_keys = HashSet::new();
304
305 for (key, val) in entries {
306 if seen_keys.insert(key) {
308 if let Some(prev) = self.local.get(key) {
309 updates.push(((prev.encoded_key.clone(), prev.encoded_val.clone()), -1));
310 }
311 if let Some(val) = val {
312 updates.push((C::encode(key, val), 1));
313 }
314 }
315 }
316 consolidate(&mut updates);
317
318 let ret = self.compare_and_append(updates, expected_upper).await;
319 match ret {
320 Ok(new_upper) => {
321 self.sync_to(Some(new_upper)).await;
322 Ok(())
323 }
324 Err(err) => {
325 self.sync_to(err.current.clone().into_option()).await;
326 Err(Error::WriteConflict(err))
327 }
328 }
329 }
330
331 async fn compare_and_append<I>(
337 &mut self,
338 updates: I,
339 write_ts: u64,
340 ) -> Result<u64, UpperMismatch<u64>>
341 where
342 I: IntoIterator<Item = ((C::KeyCodec, C::ValCodec), i64)>,
347 {
348 let expected_upper = write_ts;
349 let new_upper = expected_upper + 1;
350 let updates = updates.into_iter().map(|((k, v), d)| ((k, v), write_ts, d));
351 self.write
352 .compare_and_append(
353 updates,
354 Antichain::from_elem(expected_upper),
355 Antichain::from_elem(new_upper),
356 )
357 .await
358 .expect("usage should be valid")?;
359
360 let downgrade_to = Antichain::from_elem(write_ts);
362
363 let opaque = *self.since_handle.opaque();
368 let ret = self
369 .since_handle
370 .compare_and_downgrade_since(&opaque, (&opaque, &downgrade_to))
371 .await;
372 if let Err(e) = ret {
373 soft_panic_or_log!("found opaque value {e}, but expected {opaque}");
374 }
375
376 Ok(new_upper)
377 }
378
379 pub async fn dangerous_compact_shard(
382 &self,
383 fuel: impl Fn() -> usize,
384 wait: impl Fn() -> Duration,
385 ) {
386 mz_persist_client::cli::admin::dangerous_force_compaction_and_break_pushdown(
387 &self.write,
388 fuel,
389 wait,
390 )
391 .await
392 }
393}
394
#[cfg(test)]
mod tests {
    use mz_ore::assert_none;
    use mz_persist_client::cache::PersistClientCache;
    use mz_persist_types::PersistLocation;
    use mz_persist_types::codec_impls::StringSchema;

    use super::*;

    /// Codec whose in-memory and encoded types are all `String`; encoding and decoding copy
    /// the inputs unchanged.
    #[derive(Debug, PartialEq, Eq)]
    struct TestCodec;

    impl DurableCacheCodec for TestCodec {
        type Key = String;
        type Val = String;
        type KeyCodec = String;
        type ValCodec = String;

        fn schemas() -> (
            <Self::KeyCodec as Codec>::Schema,
            <Self::ValCodec as Codec>::Schema,
        ) {
            (StringSchema, StringSchema)
        }

        fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec) {
            (key.to_owned(), val.to_owned())
        }

        fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val) {
            (key.to_owned(), val.to_owned())
        }
    }

    /// Shorthand for building an owned `String` from a literal.
    fn s(text: &str) -> String {
        text.to_owned()
    }

    /// Owned copy of the cache's local entries, in iteration order.
    fn snapshot(cache: &DurableCache<TestCodec>) -> Vec<(String, String)> {
        cache
            .entries_local()
            .map(|(key, val)| (key.clone(), val.clone()))
            .collect()
    }

    /// Owned key/value pairs built from string literals, for comparison with `snapshot`.
    fn pairs(expected: &[(&str, &str)]) -> Vec<(String, String)> {
        expected
            .iter()
            .map(|&(key, val)| (s(key), s(val)))
            .collect()
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn durable_cache() {
        let client_cache = PersistClientCache::new_no_metrics();
        let persist = client_cache
            .open(PersistLocation::new_in_mem())
            .await
            .expect("location should be valid");
        let shard_id = ShardId::new();

        // A fresh cache is empty; `get` computes and stores the missing value.
        let mut cache0 = DurableCache::<TestCodec>::new(&persist, shard_id, "test1").await;
        assert_none!(cache0.get_local(&s("foo")));
        assert_eq!(cache0.get(&s("foo"), || s("bar")).await, "bar");
        assert_eq!(snapshot(&cache0), pairs(&[("foo", "bar")]));

        // Values written with `set` are visible locally and are not replaced by `get`.
        for (key, val) in [("k1", "v1"), ("k2", "v2")] {
            cache0.set(&s(key), Some(&s(val))).await;
        }
        for (key, stored, unused) in [("k1", "v1", "v10"), ("k2", "v2", "v20")] {
            assert_eq!(cache0.get_local(&s(key)), Some(&s(stored)));
            assert_eq!(cache0.get(&s(key), || s(unused)).await, stored);
        }
        assert_eq!(
            snapshot(&cache0),
            pairs(&[("foo", "bar"), ("k1", "v1"), ("k2", "v2")])
        );

        // Setting a key to `None` removes it.
        cache0.set(&s("k1"), None).await;
        assert_none!(cache0.get_local(&s("k1")));
        assert_eq!(snapshot(&cache0), pairs(&[("foo", "bar"), ("k2", "v2")]));

        // A batch can mix insertions, removals, and removals of absent keys.
        cache0
            .set_many(&[
                (&s("k1"), Some(&s("v10"))),
                (&s("k2"), None),
                (&s("k3"), None),
            ])
            .await;
        assert_eq!(cache0.get_local(&s("k1")), Some(&s("v10")));
        assert_none!(cache0.get_local(&s("k2")));
        assert_none!(cache0.get_local(&s("k3")));
        assert_eq!(snapshot(&cache0), pairs(&[("foo", "bar"), ("k1", "v10")]));

        // With duplicate keys in one batch, the first occurrence wins.
        cache0
            .set_many(&[
                (&s("k4"), Some(&s("v40"))),
                (&s("k4"), Some(&s("v41"))),
                (&s("k4"), Some(&s("v42"))),
                (&s("k5"), Some(&s("v50"))),
                (&s("k5"), Some(&s("v51"))),
                (&s("k5"), Some(&s("v52"))),
            ])
            .await;
        assert_eq!(cache0.get_local(&s("k4")), Some(&s("v40")));
        assert_eq!(cache0.get_local(&s("k5")), Some(&s("v50")));
        let final_entries = pairs(&[("foo", "bar"), ("k1", "v10"), ("k4", "v40"), ("k5", "v50")]);
        assert_eq!(snapshot(&cache0), final_entries);

        // A second cache on the same shard sees the same entries without computing anything.
        let mut cache1 = DurableCache::<TestCodec>::new(&persist, shard_id, "test2").await;
        assert_eq!(cache1.get(&s("foo"), || panic!("boom")).await, "bar");
        assert_eq!(cache1.get(&s("k1"), || panic!("boom")).await, "v10");
        assert_none!(cache1.get_local(&s("k2")));
        assert_none!(cache1.get_local(&s("k3")));
        assert_eq!(snapshot(&cache1), final_entries);

        // Exercise the compaction entry point.
        let fuel = || 131_072;
        let wait = || Duration::from_millis(0);
        cache1.dangerous_compact_shard(fuel, wait).await;
    }
}