1use std::collections::BTreeMap;
13use std::fmt::{Debug, Formatter};
14use std::hash::Hash;
15use std::sync::Arc;
16use std::time::Duration;
17
18use differential_dataflow::consolidation::{consolidate, consolidate_updates};
19use mz_ore::collections::{AssociativeExt, HashSet};
20use mz_ore::soft_panic_or_log;
21use mz_persist_client::critical::{Opaque, SinceHandle};
22use mz_persist_client::error::UpperMismatch;
23use mz_persist_client::read::{ListenEvent, Subscribe};
24use mz_persist_client::write::WriteHandle;
25use mz_persist_client::{Diagnostics, PersistClient};
26use mz_persist_types::{Codec, ShardId};
27use timely::progress::Antichain;
28use tracing::debug;
29
30pub trait DurableCacheCodec: Debug + Eq {
31 type Key: Ord + Hash + Clone + Debug;
32 type Val: Eq + Debug;
33 type KeyCodec: Codec + Ord + Debug + Clone;
34 type ValCodec: Codec + Ord + Debug + Clone;
35
36 fn schemas() -> (
37 <Self::KeyCodec as Codec>::Schema,
38 <Self::ValCodec as Codec>::Schema,
39 );
40 fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec);
41 fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val);
42}
43
44#[derive(Debug)]
45pub enum Error {
46 WriteConflict(UpperMismatch<u64>),
47}
48
49impl std::fmt::Display for Error {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Error::WriteConflict(err) => write!(f, "{err}"),
53 }
54 }
55}
56
57#[derive(Debug, PartialEq, Eq)]
58struct LocalVal<C: DurableCacheCodec> {
59 encoded_key: C::KeyCodec,
60 decoded_val: C::Val,
61 encoded_val: C::ValCodec,
62}
63
64#[derive(Debug)]
65pub struct DurableCache<C: DurableCacheCodec> {
66 since_handle: SinceHandle<C::KeyCodec, C::ValCodec, u64, i64>,
67 write: WriteHandle<C::KeyCodec, C::ValCodec, u64, i64>,
68 subscribe: Subscribe<C::KeyCodec, C::ValCodec, u64, i64>,
69
70 local: BTreeMap<C::Key, LocalVal<C>>,
71 local_progress: u64,
72}
73
74const USE_CRITICAL_SINCE: bool = true;
75
76impl<C: DurableCacheCodec> DurableCache<C> {
77 pub async fn new(persist: &PersistClient, shard_id: ShardId, purpose: &str) -> Self {
79 let diagnostics = Diagnostics {
80 shard_name: format!("{purpose}_cache"),
81 handle_purpose: format!("durable persist cache: {purpose}"),
82 };
83 let since_handle = persist
84 .open_critical_since(
85 shard_id,
86 PersistClient::CONTROLLER_CRITICAL_SINCE,
89 Opaque::encode(&i64::MIN),
90 diagnostics.clone(),
91 )
92 .await
93 .expect("invalid usage");
94 let (key_schema, val_schema) = C::schemas();
95 let (mut write, read) = persist
96 .open(
97 shard_id,
98 Arc::new(key_schema),
99 Arc::new(val_schema),
100 diagnostics,
101 USE_CRITICAL_SINCE,
102 )
103 .await
104 .expect("shard codecs should not change");
105 let res = write
107 .compare_and_append_batch(
108 &mut [],
109 Antichain::from_elem(0),
110 Antichain::from_elem(1),
111 true,
112 )
113 .await
114 .expect("usage was valid");
115 match res {
116 Ok(()) => {}
118 Err(UpperMismatch { .. }) => {}
120 }
121
122 let as_of = read.since().clone();
123 let subscribe = read
124 .subscribe(as_of)
125 .await
126 .expect("capability should be held at this since");
127 let mut ret = DurableCache {
128 since_handle,
129 write,
130 subscribe,
131 local: BTreeMap::new(),
132 local_progress: 0,
133 };
134 ret.sync_to(ret.write.upper().as_option().copied()).await;
135 ret
136 }
137
138 async fn sync_to(&mut self, progress: Option<u64>) -> u64 {
139 let progress = progress.expect("cache shard should not be closed");
140 let mut updates: BTreeMap<_, Vec<_>> = BTreeMap::new();
141
142 while self.local_progress < progress {
143 let events = self.subscribe.fetch_next().await;
144 for event in events {
145 match event {
146 ListenEvent::Updates(batch_updates) => {
147 debug!("syncing updates {batch_updates:?}");
148 for update in batch_updates {
149 updates.entry(update.1).or_default().push(update);
150 }
151 }
152 ListenEvent::Progress(x) => {
153 debug!("synced up to {x:?}");
154 self.local_progress =
155 x.into_option().expect("cache shard should not be closed");
156 while let Some((ts, mut updates)) = updates.pop_first() {
159 assert!(
160 ts < self.local_progress,
161 "expected {} < {}",
162 ts,
163 self.local_progress
164 );
165 assert!(
166 updates.iter().all(|(_, update_ts, _)| ts == *update_ts),
167 "all updates should be for time {ts}, updates: {updates:?}"
168 );
169
170 consolidate_updates(&mut updates);
171 updates.sort_by(|(_, _, d1), (_, _, d2)| d1.cmp(d2));
172 for ((k, v), t, d) in updates {
173 let encoded_key = k;
174 let encoded_val = v;
175 let (decoded_key, decoded_val) =
176 C::decode(&encoded_key, &encoded_val);
177 let val = LocalVal {
178 encoded_key,
179 decoded_val,
180 encoded_val,
181 };
182
183 if d == 1 {
184 self.local.expect_insert(
185 decoded_key,
186 val,
187 "duplicate cache entry",
188 );
189 } else if d == -1 {
190 let prev = self
191 .local
192 .expect_remove(&decoded_key, "entry does not exist");
193 assert_eq!(
194 val, prev,
195 "removed val does not match expected val"
196 );
197 } else {
198 panic!(
199 "unexpected diff: (({:?}, {:?}), {}, {})",
200 decoded_key, val.decoded_val, t, d
201 );
202 }
203 }
204 }
205 }
206 }
207 }
208 }
209 assert_eq!(updates, BTreeMap::new(), "all updates should be applied");
210 progress
211 }
212
213 pub fn get_local(&self, key: &C::Key) -> Option<&C::Val> {
216 self.local.get(key).map(|val| &val.decoded_val)
217 }
218
219 pub async fn get(&mut self, key: &C::Key, val_fn: impl FnOnce() -> C::Val) -> &C::Val {
223 if self.local.contains_key(key) {
227 return self.get_local(key).expect("checked above");
228 }
229
230 self.sync_to(self.write.shared_upper().into_option()).await;
233 if self.local.contains_key(key) {
234 return self.get_local(key).expect("checked above");
235 }
236
237 let val = val_fn();
239 let mut expected_upper = self.local_progress;
240 let update = (C::encode(key, &val), 1);
241 loop {
242 let ret = self
243 .compare_and_append([update.clone()], expected_upper)
244 .await;
245 match ret {
246 Ok(new_upper) => {
247 self.sync_to(Some(new_upper)).await;
248 return self.get_local(key).expect("just inserted");
249 }
250 Err(err) => {
251 expected_upper = self.sync_to(err.current.into_option()).await;
252 if self.local.contains_key(key) {
253 return self.get_local(key).expect("checked above");
254 }
255 continue;
256 }
257 }
258 }
259 }
260
261 pub fn entries_local(&self) -> impl Iterator<Item = (&C::Key, &C::Val)> {
263 self.local.iter().map(|(key, val)| (key, &val.decoded_val))
264 }
265
266 pub async fn set(&mut self, key: &C::Key, value: Option<&C::Val>) {
270 while let Err(err) = self.try_set(key, value).await {
271 debug!("failed to set entry: {err} ... retrying");
272 }
273 }
274
275 pub async fn set_many(&mut self, entries: &[(&C::Key, Option<&C::Val>)]) {
280 while let Err(err) = self.try_set_many(entries).await {
281 debug!("failed to set entries: {err} ... retrying");
282 }
283 }
284
285 pub async fn try_set(&mut self, key: &C::Key, value: Option<&C::Val>) -> Result<(), Error> {
290 self.try_set_many(&[(key, value)]).await
291 }
292
293 pub async fn try_set_many(
299 &mut self,
300 entries: &[(&C::Key, Option<&C::Val>)],
301 ) -> Result<(), Error> {
302 let expected_upper = self.local_progress;
303 let mut updates = Vec::new();
304 let mut seen_keys = HashSet::new();
305
306 for (key, val) in entries {
307 if seen_keys.insert(key) {
309 if let Some(prev) = self.local.get(key) {
310 updates.push(((prev.encoded_key.clone(), prev.encoded_val.clone()), -1));
311 }
312 if let Some(val) = val {
313 updates.push((C::encode(key, val), 1));
314 }
315 }
316 }
317 consolidate(&mut updates);
318
319 let ret = self.compare_and_append(updates, expected_upper).await;
320 match ret {
321 Ok(new_upper) => {
322 self.sync_to(Some(new_upper)).await;
323 Ok(())
324 }
325 Err(err) => {
326 self.sync_to(err.current.clone().into_option()).await;
327 Err(Error::WriteConflict(err))
328 }
329 }
330 }
331
332 async fn compare_and_append<I>(
338 &mut self,
339 updates: I,
340 write_ts: u64,
341 ) -> Result<u64, UpperMismatch<u64>>
342 where
343 I: IntoIterator<Item = ((C::KeyCodec, C::ValCodec), i64)>,
348 {
349 let expected_upper = write_ts;
350 let new_upper = expected_upper + 1;
351 let updates = updates.into_iter().map(|((k, v), d)| ((k, v), write_ts, d));
352 self.write
353 .compare_and_append(
354 updates,
355 Antichain::from_elem(expected_upper),
356 Antichain::from_elem(new_upper),
357 )
358 .await
359 .expect("usage should be valid")?;
360
361 let downgrade_to = Antichain::from_elem(write_ts);
363
364 let opaque = self.since_handle.opaque().clone();
369 let ret = self
370 .since_handle
371 .compare_and_downgrade_since(&opaque, (&opaque, &downgrade_to))
372 .await;
373 if let Err(e) = ret {
374 soft_panic_or_log!("found opaque value {e:?}, but expected {opaque:?}");
375 }
376
377 Ok(new_upper)
378 }
379
380 pub async fn dangerous_compact_shard(
383 &self,
384 fuel: impl Fn() -> usize,
385 wait: impl Fn() -> Duration,
386 ) {
387 mz_persist_client::cli::admin::dangerous_force_compaction_and_break_pushdown(
388 &self.write,
389 fuel,
390 wait,
391 )
392 .await
393 }
394
395 pub async fn upgrade_version(&self) {
398 self.since_handle
399 .upgrade_version()
400 .await
401 .expect("invalid usage")
402 }
403}
404
#[cfg(test)]
mod tests {
    use mz_ore::assert_none;
    use mz_persist_client::cache::PersistClientCache;
    use mz_persist_types::PersistLocation;
    use mz_persist_types::codec_impls::StringSchema;

    use super::*;

    /// Codec for the tests: keys and values are `String`s that encode to themselves.
    #[derive(Debug, PartialEq, Eq)]
    struct TestCodec;

    impl DurableCacheCodec for TestCodec {
        type Key = String;
        type Val = String;
        type KeyCodec = String;
        type ValCodec = String;

        fn schemas() -> (
            <Self::KeyCodec as Codec>::Schema,
            <Self::ValCodec as Codec>::Schema,
        ) {
            (StringSchema, StringSchema)
        }

        fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec) {
            (key.to_owned(), val.to_owned())
        }

        fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val) {
            (key.to_owned(), val.to_owned())
        }
    }

    /// Shorthand for building an owned `String` from a literal.
    fn s(text: &str) -> String {
        text.to_owned()
    }

    /// Copies the local entries of `cache` into an owned vector, in iteration order.
    fn snapshot(cache: &DurableCache<TestCodec>) -> Vec<(String, String)> {
        cache
            .entries_local()
            .map(|(key, val)| (key.clone(), val.clone()))
            .collect()
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn durable_cache() {
        let client_cache = PersistClientCache::new_no_metrics();
        let persist = client_cache
            .open(PersistLocation::new_in_mem())
            .await
            .expect("location should be valid");
        let shard_id = ShardId::new();

        // A missing key is computed and inserted by `get`.
        let mut cache0 = DurableCache::<TestCodec>::new(&persist, shard_id, "test1").await;
        assert_none!(cache0.get_local(&s("foo")));
        assert_eq!(cache0.get(&s("foo"), || s("bar")).await, "bar");
        assert_eq!(snapshot(&cache0), vec![(s("foo"), s("bar"))]);

        // Keys written with `set` are visible locally and are not recomputed by `get`.
        cache0.set(&s("k1"), Some(&s("v1"))).await;
        cache0.set(&s("k2"), Some(&s("v2"))).await;
        assert_eq!(cache0.get_local(&s("k1")), Some(&s("v1")));
        assert_eq!(cache0.get(&s("k1"), || s("v10")).await, &"v1");
        assert_eq!(cache0.get_local(&s("k2")), Some(&s("v2")));
        assert_eq!(cache0.get(&s("k2"), || s("v20")).await, &"v2");
        assert_eq!(
            snapshot(&cache0),
            vec![(s("foo"), s("bar")), (s("k1"), s("v1")), (s("k2"), s("v2"))]
        );

        // Setting a key to `None` removes it.
        cache0.set(&s("k1"), None).await;
        assert_none!(cache0.get_local(&s("k1")));
        assert_eq!(
            snapshot(&cache0),
            vec![(s("foo"), s("bar")), (s("k2"), s("v2"))]
        );

        // `set_many` can insert, remove, and "remove" a key that does not exist in one call.
        cache0
            .set_many(&[
                (&s("k1"), Some(&s("v10"))),
                (&s("k2"), None),
                (&s("k3"), None),
            ])
            .await;
        assert_eq!(cache0.get_local(&s("k1")), Some(&s("v10")));
        assert_none!(cache0.get_local(&s("k2")));
        assert_none!(cache0.get_local(&s("k3")));
        assert_eq!(
            snapshot(&cache0),
            vec![(s("foo"), s("bar")), (s("k1"), s("v10"))]
        );

        // With duplicate keys in one `set_many` call, the first occurrence is the one stored.
        cache0
            .set_many(&[
                (&s("k4"), Some(&s("v40"))),
                (&s("k4"), Some(&s("v41"))),
                (&s("k4"), Some(&s("v42"))),
                (&s("k5"), Some(&s("v50"))),
                (&s("k5"), Some(&s("v51"))),
                (&s("k5"), Some(&s("v52"))),
            ])
            .await;
        assert_eq!(cache0.get_local(&s("k4")), Some(&s("v40")));
        assert_eq!(cache0.get_local(&s("k5")), Some(&s("v50")));
        assert_eq!(
            snapshot(&cache0),
            vec![
                (s("foo"), s("bar")),
                (s("k1"), s("v10")),
                (s("k4"), s("v40")),
                (s("k5"), s("v50")),
            ]
        );

        // A second cache opened on the same shard sees the same contents and never has to
        // compute a value.
        let mut cache1 = DurableCache::<TestCodec>::new(&persist, shard_id, "test2").await;
        assert_eq!(cache1.get(&s("foo"), || panic!("boom")).await, "bar");
        assert_eq!(cache1.get(&s("k1"), || panic!("boom")).await, &"v10");
        assert_none!(cache1.get_local(&s("k2")));
        assert_none!(cache1.get_local(&s("k3")));
        assert_eq!(
            snapshot(&cache1),
            vec![
                (s("foo"), s("bar")),
                (s("k1"), s("v10")),
                (s("k4"), s("v40")),
                (s("k5"), s("v50")),
            ]
        );

        // Finally, exercise forced compaction of the shard.
        let fuel = || 131_072;
        let wait = || Duration::from_millis(0);
        cache1.dangerous_compact_shard(fuel, wait).await
    }
}