1use std::collections::BTreeMap;
13use std::fmt::{Debug, Formatter};
14use std::hash::Hash;
15use std::sync::Arc;
16use std::time::Duration;
17
18use differential_dataflow::consolidation::{consolidate, consolidate_updates};
19use mz_ore::collections::{AssociativeExt, HashSet};
20use mz_ore::soft_panic_or_log;
21use mz_persist_client::critical::SinceHandle;
22use mz_persist_client::error::UpperMismatch;
23use mz_persist_client::read::{ListenEvent, Subscribe};
24use mz_persist_client::write::WriteHandle;
25use mz_persist_client::{Diagnostics, PersistClient};
26use mz_persist_types::{Codec, ShardId};
27use timely::progress::Antichain;
28use tracing::debug;
29
30pub trait DurableCacheCodec: Debug + Eq {
31 type Key: Ord + Hash + Clone + Debug;
32 type Val: Eq + Debug;
33 type KeyCodec: Codec + Ord + Debug + Clone;
34 type ValCodec: Codec + Ord + Debug + Clone;
35
36 fn schemas() -> (
37 <Self::KeyCodec as Codec>::Schema,
38 <Self::ValCodec as Codec>::Schema,
39 );
40 fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec);
41 fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val);
42}
43
44#[derive(Debug)]
45pub enum Error {
46 WriteConflict(UpperMismatch<u64>),
47}
48
49impl std::fmt::Display for Error {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Error::WriteConflict(err) => write!(f, "{err}"),
53 }
54 }
55}
56
57#[derive(Debug, PartialEq, Eq)]
58struct LocalVal<C: DurableCacheCodec> {
59 encoded_key: C::KeyCodec,
60 decoded_val: C::Val,
61 encoded_val: C::ValCodec,
62}
63
64#[derive(Debug)]
65pub struct DurableCache<C: DurableCacheCodec> {
66 since_handle: SinceHandle<C::KeyCodec, C::ValCodec, u64, i64, i64>,
67 write: WriteHandle<C::KeyCodec, C::ValCodec, u64, i64>,
68 subscribe: Subscribe<C::KeyCodec, C::ValCodec, u64, i64>,
69
70 local: BTreeMap<C::Key, LocalVal<C>>,
71 local_progress: u64,
72}
73
/// Passed as the final argument of `PersistClient::open` when the cache opens its shard.
// NOTE(review): presumably tells persist that the reader's since is held by the
// critical since handle rather than a leased reader — confirm against persist docs.
const USE_CRITICAL_SINCE: bool = true;
75
76impl<C: DurableCacheCodec> DurableCache<C> {
77 pub async fn new(persist: &PersistClient, shard_id: ShardId, purpose: &str) -> Self {
79 let diagnostics = Diagnostics {
80 shard_name: format!("{purpose}_cache"),
81 handle_purpose: format!("durable persist cache: {purpose}"),
82 };
83 let since_handle = persist
84 .open_critical_since(
85 shard_id,
86 PersistClient::CONTROLLER_CRITICAL_SINCE,
89 diagnostics.clone(),
90 )
91 .await
92 .expect("invalid usage");
93 let (key_schema, val_schema) = C::schemas();
94 let (mut write, read) = persist
95 .open(
96 shard_id,
97 Arc::new(key_schema),
98 Arc::new(val_schema),
99 diagnostics,
100 USE_CRITICAL_SINCE,
101 )
102 .await
103 .expect("shard codecs should not change");
104 let res = write
106 .compare_and_append_batch(
107 &mut [],
108 Antichain::from_elem(0),
109 Antichain::from_elem(1),
110 true,
111 )
112 .await
113 .expect("usage was valid");
114 match res {
115 Ok(()) => {}
117 Err(UpperMismatch { .. }) => {}
119 }
120
121 let as_of = read.since().clone();
122 let subscribe = read
123 .subscribe(as_of)
124 .await
125 .expect("capability should be held at this since");
126 let mut ret = DurableCache {
127 since_handle,
128 write,
129 subscribe,
130 local: BTreeMap::new(),
131 local_progress: 0,
132 };
133 ret.sync_to(ret.write.upper().as_option().copied()).await;
134 ret
135 }
136
137 async fn sync_to(&mut self, progress: Option<u64>) -> u64 {
138 let progress = progress.expect("cache shard should not be closed");
139 let mut updates: BTreeMap<_, Vec<_>> = BTreeMap::new();
140
141 while self.local_progress < progress {
142 let events = self.subscribe.fetch_next().await;
143 for event in events {
144 match event {
145 ListenEvent::Updates(batch_updates) => {
146 debug!("syncing updates {batch_updates:?}");
147 for update in batch_updates {
148 updates.entry(update.1).or_default().push(update);
149 }
150 }
151 ListenEvent::Progress(x) => {
152 debug!("synced up to {x:?}");
153 self.local_progress =
154 x.into_option().expect("cache shard should not be closed");
155 while let Some((ts, mut updates)) = updates.pop_first() {
158 assert!(
159 ts < self.local_progress,
160 "expected {} < {}",
161 ts,
162 self.local_progress
163 );
164 assert!(
165 updates.iter().all(|(_, update_ts, _)| ts == *update_ts),
166 "all updates should be for time {ts}, updates: {updates:?}"
167 );
168
169 consolidate_updates(&mut updates);
170 updates.sort_by(|(_, _, d1), (_, _, d2)| d1.cmp(d2));
171 for ((k, v), t, d) in updates {
172 let encoded_key = k.unwrap();
173 let encoded_val = v.unwrap();
174 let (decoded_key, decoded_val) =
175 C::decode(&encoded_key, &encoded_val);
176 let val = LocalVal {
177 encoded_key,
178 decoded_val,
179 encoded_val,
180 };
181
182 if d == 1 {
183 self.local.expect_insert(
184 decoded_key,
185 val,
186 "duplicate cache entry",
187 );
188 } else if d == -1 {
189 let prev = self
190 .local
191 .expect_remove(&decoded_key, "entry does not exist");
192 assert_eq!(
193 val, prev,
194 "removed val does not match expected val"
195 );
196 } else {
197 panic!(
198 "unexpected diff: (({:?}, {:?}), {}, {})",
199 decoded_key, val.decoded_val, t, d
200 );
201 }
202 }
203 }
204 }
205 }
206 }
207 }
208 assert_eq!(updates, BTreeMap::new(), "all updates should be applied");
209 progress
210 }
211
212 pub fn get_local(&self, key: &C::Key) -> Option<&C::Val> {
215 self.local.get(key).map(|val| &val.decoded_val)
216 }
217
218 pub async fn get(&mut self, key: &C::Key, val_fn: impl FnOnce() -> C::Val) -> &C::Val {
222 if self.local.contains_key(key) {
226 return self.get_local(key).expect("checked above");
227 }
228
229 self.sync_to(self.write.shared_upper().into_option()).await;
232 if self.local.contains_key(key) {
233 return self.get_local(key).expect("checked above");
234 }
235
236 let val = val_fn();
238 let mut expected_upper = self.local_progress;
239 let update = (C::encode(key, &val), 1);
240 loop {
241 let ret = self
242 .compare_and_append([update.clone()], expected_upper)
243 .await;
244 match ret {
245 Ok(new_upper) => {
246 self.sync_to(Some(new_upper)).await;
247 return self.get_local(key).expect("just inserted");
248 }
249 Err(err) => {
250 expected_upper = self.sync_to(err.current.into_option()).await;
251 if self.local.contains_key(key) {
252 return self.get_local(key).expect("checked above");
253 }
254 continue;
255 }
256 }
257 }
258 }
259
260 pub fn entries_local(&self) -> impl Iterator<Item = (&C::Key, &C::Val)> {
262 self.local.iter().map(|(key, val)| (key, &val.decoded_val))
263 }
264
265 pub async fn set(&mut self, key: &C::Key, value: Option<&C::Val>) {
269 while let Err(err) = self.try_set(key, value).await {
270 debug!("failed to set entry: {err} ... retrying");
271 }
272 }
273
274 pub async fn set_many(&mut self, entries: &[(&C::Key, Option<&C::Val>)]) {
279 while let Err(err) = self.try_set_many(entries).await {
280 debug!("failed to set entries: {err} ... retrying");
281 }
282 }
283
284 pub async fn try_set(&mut self, key: &C::Key, value: Option<&C::Val>) -> Result<(), Error> {
289 self.try_set_many(&[(key, value)]).await
290 }
291
292 pub async fn try_set_many(
298 &mut self,
299 entries: &[(&C::Key, Option<&C::Val>)],
300 ) -> Result<(), Error> {
301 let expected_upper = self.local_progress;
302 let mut updates = Vec::new();
303 let mut seen_keys = HashSet::new();
304
305 for (key, val) in entries {
306 if seen_keys.insert(key) {
308 if let Some(prev) = self.local.get(key) {
309 updates.push(((prev.encoded_key.clone(), prev.encoded_val.clone()), -1));
310 }
311 if let Some(val) = val {
312 updates.push((C::encode(key, val), 1));
313 }
314 }
315 }
316 consolidate(&mut updates);
317
318 let ret = self.compare_and_append(updates, expected_upper).await;
319 match ret {
320 Ok(new_upper) => {
321 self.sync_to(Some(new_upper)).await;
322 Ok(())
323 }
324 Err(err) => {
325 self.sync_to(err.current.clone().into_option()).await;
326 Err(Error::WriteConflict(err))
327 }
328 }
329 }
330
331 async fn compare_and_append<I>(
337 &mut self,
338 updates: I,
339 write_ts: u64,
340 ) -> Result<u64, UpperMismatch<u64>>
341 where
342 I: IntoIterator<Item = ((C::KeyCodec, C::ValCodec), i64)>,
347 {
348 let expected_upper = write_ts;
349 let new_upper = expected_upper + 1;
350 let updates = updates.into_iter().map(|((k, v), d)| ((k, v), write_ts, d));
351 self.write
352 .compare_and_append(
353 updates,
354 Antichain::from_elem(expected_upper),
355 Antichain::from_elem(new_upper),
356 )
357 .await
358 .expect("usage should be valid")?;
359
360 let downgrade_to = Antichain::from_elem(write_ts);
362
363 let opaque = *self.since_handle.opaque();
368 let ret = self
369 .since_handle
370 .compare_and_downgrade_since(&opaque, (&opaque, &downgrade_to))
371 .await;
372 if let Err(e) = ret {
373 soft_panic_or_log!("found opaque value {e}, but expected {opaque}");
374 }
375
376 Ok(new_upper)
377 }
378
379 pub async fn dangerous_compact_shard(
382 &self,
383 fuel: impl Fn() -> usize,
384 wait: impl Fn() -> Duration,
385 ) {
386 mz_persist_client::cli::admin::dangerous_force_compaction_and_break_pushdown(
387 &self.write,
388 fuel,
389 wait,
390 )
391 .await
392 }
393
394 pub async fn upgrade_version(&self) {
397 self.since_handle
398 .upgrade_version()
399 .await
400 .expect("invalid usage")
401 }
402}
403
#[cfg(test)]
mod tests {
    use mz_ore::assert_none;
    use mz_persist_client::cache::PersistClientCache;
    use mz_persist_types::PersistLocation;
    use mz_persist_types::codec_impls::StringSchema;

    use super::*;

    /// Identity codec: keys and values are plain `String`s stored exactly as given.
    #[derive(Debug, PartialEq, Eq)]
    struct TestCodec;

    impl DurableCacheCodec for TestCodec {
        type Key = String;
        type Val = String;
        type KeyCodec = String;
        type ValCodec = String;

        fn schemas() -> (
            <Self::KeyCodec as Codec>::Schema,
            <Self::ValCodec as Codec>::Schema,
        ) {
            (StringSchema, StringSchema)
        }

        fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec) {
            (key.clone(), val.clone())
        }

        fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val) {
            (key.clone(), val.clone())
        }
    }

    /// Returns an owned copy of `cache`'s local entries, in key order.
    fn local_entries(cache: &DurableCache<TestCodec>) -> Vec<(String, String)> {
        cache
            .entries_local()
            .map(|(key, val)| (key.clone(), val.clone()))
            .collect()
    }

    /// Builds the owned `(key, value)` list that `local_entries` should return.
    fn owned_pairs(expected: &[(&str, &str)]) -> Vec<(String, String)> {
        expected
            .iter()
            .map(|&(key, val)| (key.to_owned(), val.to_owned()))
            .collect()
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn durable_cache() {
        // Keep the client cache alive for the whole test.
        let client_cache = PersistClientCache::new_no_metrics();
        let persist = client_cache
            .open(PersistLocation::new_in_mem())
            .await
            .expect("location should be valid");
        let shard_id = ShardId::new();

        // `get` computes and durably stores a missing value.
        let mut cache0 = DurableCache::<TestCodec>::new(&persist, shard_id, "test1").await;
        assert_none!(cache0.get_local(&"foo".into()));
        assert_eq!(cache0.get(&"foo".into(), || "bar".into()).await, "bar");
        assert_eq!(local_entries(&cache0), owned_pairs(&[("foo", "bar")]));

        // Values installed with `set` win over the fallback passed to `get`.
        cache0.set(&"k1".into(), Some(&"v1".into())).await;
        cache0.set(&"k2".into(), Some(&"v2".into())).await;
        assert_eq!(cache0.get_local(&"k1".into()), Some(&"v1".into()));
        assert_eq!(cache0.get(&"k1".into(), || "v10".into()).await, &"v1");
        assert_eq!(cache0.get_local(&"k2".into()), Some(&"v2".into()));
        assert_eq!(cache0.get(&"k2".into(), || "v20".into()).await, &"v2");
        assert_eq!(
            local_entries(&cache0),
            owned_pairs(&[("foo", "bar"), ("k1", "v1"), ("k2", "v2")])
        );

        // Setting `None` deletes an entry.
        cache0.set(&"k1".into(), None).await;
        assert_none!(cache0.get_local(&"k1".into()));
        assert_eq!(
            local_entries(&cache0),
            owned_pairs(&[("foo", "bar"), ("k2", "v2")])
        );

        // `set_many` mixes inserts and deletes, including deletes of absent keys.
        cache0
            .set_many(&[
                (&"k1".into(), Some(&"v10".into())),
                (&"k2".into(), None),
                (&"k3".into(), None),
            ])
            .await;
        assert_eq!(cache0.get_local(&"k1".into()), Some(&"v10".into()));
        assert_none!(cache0.get_local(&"k2".into()));
        assert_none!(cache0.get_local(&"k3".into()));
        assert_eq!(
            local_entries(&cache0),
            owned_pairs(&[("foo", "bar"), ("k1", "v10")])
        );

        // Within one `set_many`, only the first occurrence of a repeated key applies.
        cache0
            .set_many(&[
                (&"k4".into(), Some(&"v40".into())),
                (&"k4".into(), Some(&"v41".into())),
                (&"k4".into(), Some(&"v42".into())),
                (&"k5".into(), Some(&"v50".into())),
                (&"k5".into(), Some(&"v51".into())),
                (&"k5".into(), Some(&"v52".into())),
            ])
            .await;
        assert_eq!(cache0.get_local(&"k4".into()), Some(&"v40".into()));
        assert_eq!(cache0.get_local(&"k5".into()), Some(&"v50".into()));
        assert_eq!(
            local_entries(&cache0),
            owned_pairs(&[("foo", "bar"), ("k1", "v10"), ("k4", "v40"), ("k5", "v50")])
        );

        // A second handle on the same shard sees everything the first one wrote, so
        // the fallbacks passed to `get` must never run.
        let mut cache1 = DurableCache::<TestCodec>::new(&persist, shard_id, "test2").await;
        assert_eq!(cache1.get(&"foo".into(), || panic!("boom")).await, "bar");
        assert_eq!(cache1.get(&"k1".into(), || panic!("boom")).await, &"v10");
        assert_none!(cache1.get_local(&"k2".into()));
        assert_none!(cache1.get_local(&"k3".into()));
        assert_eq!(
            local_entries(&cache1),
            owned_pairs(&[("foo", "bar"), ("k1", "v10"), ("k4", "v40"), ("k5", "v50")])
        );

        // Forced compaction of the shard runs to completion.
        let fuel = || 131_072;
        let wait = || Duration::from_millis(0);
        cache1.dangerous_compact_shard(fuel, wait).await;
    }
}