1use std::collections::BTreeMap;
13use std::fmt::{Debug, Formatter};
14use std::hash::Hash;
15use std::sync::Arc;
16use std::time::Duration;
17
18use differential_dataflow::consolidation::{consolidate, consolidate_updates};
19use mz_ore::collections::{AssociativeExt, HashSet};
20use mz_ore::soft_panic_or_log;
21use mz_persist_client::critical::SinceHandle;
22use mz_persist_client::error::UpperMismatch;
23use mz_persist_client::read::{ListenEvent, Subscribe};
24use mz_persist_client::write::WriteHandle;
25use mz_persist_client::{Diagnostics, PersistClient};
26use mz_persist_types::{Codec, ShardId};
27use timely::progress::Antichain;
28use tracing::debug;
29
/// Describes how a [`DurableCache`] converts between the decoded key/value types it exposes to
/// callers and the encoded types it stores in its persist shard.
pub trait DurableCacheCodec: Debug + Eq {
    /// The decoded key type; used to index the cache's local in-memory map.
    type Key: Ord + Hash + Clone + Debug;
    /// The decoded value type handed back to callers.
    type Val: Eq + Debug;
    /// The encoded form of a key, as written to and read from the shard.
    type KeyCodec: Codec + Ord + Debug + Clone;
    /// The encoded form of a value, as written to and read from the shard.
    type ValCodec: Codec + Ord + Debug + Clone;

    /// Returns the `(key, value)` schemas that are passed to persist when the shard is opened.
    fn schemas() -> (
        <Self::KeyCodec as Codec>::Schema,
        <Self::ValCodec as Codec>::Schema,
    );
    /// Encodes a decoded key-value pair into the pair of types stored in the shard.
    fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec);
    /// Decodes an encoded key-value pair read from the shard back into the decoded types.
    fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val);
}
43
/// Errors returned by the fallible write methods of [`DurableCache`].
#[derive(Debug)]
pub enum Error {
    /// The write was rejected because the shard's upper did not match the upper the cache
    /// expected; carries the mismatch reported by persist.
    WriteConflict(UpperMismatch<u64>),
}
48
49impl std::fmt::Display for Error {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Error::WriteConflict(err) => write!(f, "{err}"),
53 }
54 }
55}
56
/// A cache entry as held in memory: the decoded value together with the encoded key and value
/// exactly as they were read from the shard.
///
/// The encoded forms are retained so that overwriting or deleting the entry can emit a
/// retraction of the stored update by cloning them (see `try_set_many`).
#[derive(Debug, PartialEq, Eq)]
struct LocalVal<C: DurableCacheCodec> {
    /// The key in its encoded form.
    encoded_key: C::KeyCodec,
    /// The value in its decoded form; this is what lookups return.
    decoded_val: C::Val,
    /// The value in its encoded form.
    encoded_val: C::ValCodec,
}
63
/// A key-value cache whose contents are durably stored in a persist shard and mirrored in a
/// local in-memory map.
#[derive(Debug)]
pub struct DurableCache<C: DurableCacheCodec> {
    /// Critical since handle; its since is downgraded after every successful write.
    since_handle: SinceHandle<C::KeyCodec, C::ValCodec, u64, i64, i64>,
    /// Write handle used to append updates and to look up the shard's upper.
    write: WriteHandle<C::KeyCodec, C::ValCodec, u64, i64>,
    /// Subscription from which `sync_to` reads the shard's updates and progress.
    subscribe: Subscribe<C::KeyCodec, C::ValCodec, u64, i64>,

    /// Local copy of the cache contents, keyed by decoded key.
    local: BTreeMap<C::Key, LocalVal<C>>,
    /// The latest progress frontier observed on `subscribe`. Every update at a smaller
    /// timestamp has been applied to `local`; it is also used as the expected upper for writes.
    local_progress: u64,
}
73
/// Passed as the final argument to `PersistClient::open` in [`DurableCache::new`].
// NOTE(review): presumably tells persist that this shard's since is governed by the critical
// since handle opened alongside the read/write handles — confirm against the persist API docs.
const USE_CRITICAL_SINCE: bool = true;
75
76impl<C: DurableCacheCodec> DurableCache<C> {
77 pub async fn new(persist: &PersistClient, shard_id: ShardId, purpose: &str) -> Self {
79 let diagnostics = Diagnostics {
80 shard_name: format!("{purpose}_cache"),
81 handle_purpose: format!("durable persist cache: {purpose}"),
82 };
83 let since_handle = persist
84 .open_critical_since(
85 shard_id,
86 PersistClient::CONTROLLER_CRITICAL_SINCE,
89 diagnostics.clone(),
90 )
91 .await
92 .expect("invalid usage");
93 let (key_schema, val_schema) = C::schemas();
94 let (mut write, read) = persist
95 .open(
96 shard_id,
97 Arc::new(key_schema),
98 Arc::new(val_schema),
99 diagnostics,
100 USE_CRITICAL_SINCE,
101 )
102 .await
103 .expect("shard codecs should not change");
104 let res = write
106 .compare_and_append_batch(&mut [], Antichain::from_elem(0), Antichain::from_elem(1))
107 .await
108 .expect("usage was valid");
109 match res {
110 Ok(()) => {}
112 Err(UpperMismatch { .. }) => {}
114 }
115
116 let as_of = read.since().clone();
117 let subscribe = read
118 .subscribe(as_of)
119 .await
120 .expect("capability should be held at this since");
121 let mut ret = DurableCache {
122 since_handle,
123 write,
124 subscribe,
125 local: BTreeMap::new(),
126 local_progress: 0,
127 };
128 ret.sync_to(ret.write.upper().as_option().copied()).await;
129 ret
130 }
131
132 async fn sync_to(&mut self, progress: Option<u64>) -> u64 {
133 let progress = progress.expect("cache shard should not be closed");
134 let mut updates: BTreeMap<_, Vec<_>> = BTreeMap::new();
135
136 while self.local_progress < progress {
137 let events = self.subscribe.fetch_next().await;
138 for event in events {
139 match event {
140 ListenEvent::Updates(batch_updates) => {
141 debug!("syncing updates {batch_updates:?}");
142 for update in batch_updates {
143 updates.entry(update.1).or_default().push(update);
144 }
145 }
146 ListenEvent::Progress(x) => {
147 debug!("synced up to {x:?}");
148 self.local_progress =
149 x.into_option().expect("cache shard should not be closed");
150 while let Some((ts, mut updates)) = updates.pop_first() {
153 assert!(
154 ts < self.local_progress,
155 "expected {} < {}",
156 ts,
157 self.local_progress
158 );
159 assert!(
160 updates.iter().all(|(_, update_ts, _)| ts == *update_ts),
161 "all updates should be for time {ts}, updates: {updates:?}"
162 );
163
164 consolidate_updates(&mut updates);
165 updates.sort_by(|(_, _, d1), (_, _, d2)| d1.cmp(d2));
166 for ((k, v), t, d) in updates {
167 let encoded_key = k.unwrap();
168 let encoded_val = v.unwrap();
169 let (decoded_key, decoded_val) =
170 C::decode(&encoded_key, &encoded_val);
171 let val = LocalVal {
172 encoded_key,
173 decoded_val,
174 encoded_val,
175 };
176
177 if d == 1 {
178 self.local.expect_insert(
179 decoded_key,
180 val,
181 "duplicate cache entry",
182 );
183 } else if d == -1 {
184 let prev = self
185 .local
186 .expect_remove(&decoded_key, "entry does not exist");
187 assert_eq!(
188 val, prev,
189 "removed val does not match expected val"
190 );
191 } else {
192 panic!(
193 "unexpected diff: (({:?}, {:?}), {}, {})",
194 decoded_key, val.decoded_val, t, d
195 );
196 }
197 }
198 }
199 }
200 }
201 }
202 }
203 assert_eq!(updates, BTreeMap::new(), "all updates should be applied");
204 progress
205 }
206
207 pub fn get_local(&self, key: &C::Key) -> Option<&C::Val> {
210 self.local.get(key).map(|val| &val.decoded_val)
211 }
212
213 pub async fn get(&mut self, key: &C::Key, val_fn: impl FnOnce() -> C::Val) -> &C::Val {
217 if self.local.contains_key(key) {
221 return self.get_local(key).expect("checked above");
222 }
223
224 self.sync_to(self.write.shared_upper().into_option()).await;
227 if self.local.contains_key(key) {
228 return self.get_local(key).expect("checked above");
229 }
230
231 let val = val_fn();
233 let mut expected_upper = self.local_progress;
234 let update = (C::encode(key, &val), 1);
235 loop {
236 let ret = self
237 .compare_and_append([update.clone()], expected_upper)
238 .await;
239 match ret {
240 Ok(new_upper) => {
241 self.sync_to(Some(new_upper)).await;
242 return self.get_local(key).expect("just inserted");
243 }
244 Err(err) => {
245 expected_upper = self.sync_to(err.current.into_option()).await;
246 if self.local.contains_key(key) {
247 return self.get_local(key).expect("checked above");
248 }
249 continue;
250 }
251 }
252 }
253 }
254
255 pub fn entries_local(&self) -> impl Iterator<Item = (&C::Key, &C::Val)> {
257 self.local.iter().map(|(key, val)| (key, &val.decoded_val))
258 }
259
260 pub async fn set(&mut self, key: &C::Key, value: Option<&C::Val>) {
264 while let Err(err) = self.try_set(key, value).await {
265 debug!("failed to set entry: {err} ... retrying");
266 }
267 }
268
269 pub async fn set_many(&mut self, entries: &[(&C::Key, Option<&C::Val>)]) {
274 while let Err(err) = self.try_set_many(entries).await {
275 debug!("failed to set entries: {err} ... retrying");
276 }
277 }
278
279 pub async fn try_set(&mut self, key: &C::Key, value: Option<&C::Val>) -> Result<(), Error> {
284 self.try_set_many(&[(key, value)]).await
285 }
286
287 pub async fn try_set_many(
293 &mut self,
294 entries: &[(&C::Key, Option<&C::Val>)],
295 ) -> Result<(), Error> {
296 let expected_upper = self.local_progress;
297 let mut updates = Vec::new();
298 let mut seen_keys = HashSet::new();
299
300 for (key, val) in entries {
301 if seen_keys.insert(key) {
303 if let Some(prev) = self.local.get(key) {
304 updates.push(((prev.encoded_key.clone(), prev.encoded_val.clone()), -1));
305 }
306 if let Some(val) = val {
307 updates.push((C::encode(key, val), 1));
308 }
309 }
310 }
311 consolidate(&mut updates);
312
313 let ret = self.compare_and_append(updates, expected_upper).await;
314 match ret {
315 Ok(new_upper) => {
316 self.sync_to(Some(new_upper)).await;
317 Ok(())
318 }
319 Err(err) => {
320 self.sync_to(err.current.clone().into_option()).await;
321 Err(Error::WriteConflict(err))
322 }
323 }
324 }
325
326 async fn compare_and_append<I>(
332 &mut self,
333 updates: I,
334 write_ts: u64,
335 ) -> Result<u64, UpperMismatch<u64>>
336 where
337 I: IntoIterator<Item = ((C::KeyCodec, C::ValCodec), i64)>,
342 {
343 let expected_upper = write_ts;
344 let new_upper = expected_upper + 1;
345 let updates = updates.into_iter().map(|((k, v), d)| ((k, v), write_ts, d));
346 self.write
347 .compare_and_append(
348 updates,
349 Antichain::from_elem(expected_upper),
350 Antichain::from_elem(new_upper),
351 )
352 .await
353 .expect("usage should be valid")?;
354
355 let downgrade_to = Antichain::from_elem(write_ts);
357
358 let opaque = *self.since_handle.opaque();
363 let ret = self
364 .since_handle
365 .compare_and_downgrade_since(&opaque, (&opaque, &downgrade_to))
366 .await;
367 if let Err(e) = ret {
368 soft_panic_or_log!("found opaque value {e}, but expected {opaque}");
369 }
370
371 Ok(new_upper)
372 }
373
374 pub async fn dangerous_compact_shard(
377 &self,
378 fuel: impl Fn() -> usize,
379 wait: impl Fn() -> Duration,
380 ) {
381 mz_persist_client::cli::admin::dangerous_force_compaction_and_break_pushdown(
382 &self.write,
383 fuel,
384 wait,
385 )
386 .await
387 }
388}
389
390#[cfg(test)]
391mod tests {
392 use mz_ore::assert_none;
393 use mz_persist_client::cache::PersistClientCache;
394 use mz_persist_types::PersistLocation;
395 use mz_persist_types::codec_impls::StringSchema;
396
397 use super::*;
398
    /// Codec used by the tests: keys and values are `String`s in both decoded and encoded form.
    #[derive(Debug, PartialEq, Eq)]
    struct TestCodec;
401
402 impl DurableCacheCodec for TestCodec {
403 type Key = String;
404 type Val = String;
405 type KeyCodec = String;
406 type ValCodec = String;
407
408 fn schemas() -> (
409 <Self::KeyCodec as Codec>::Schema,
410 <Self::ValCodec as Codec>::Schema,
411 ) {
412 (StringSchema, StringSchema)
413 }
414
415 fn encode(key: &Self::Key, val: &Self::Val) -> (Self::KeyCodec, Self::ValCodec) {
416 (key.clone(), val.clone())
417 }
418
419 fn decode(key: &Self::KeyCodec, val: &Self::ValCodec) -> (Self::Key, Self::Val) {
420 (key.clone(), val.clone())
421 }
422 }
423
424 #[mz_ore::test(tokio::test)]
425 #[cfg_attr(miri, ignore)]
426 async fn durable_cache() {
427 let persist = PersistClientCache::new_no_metrics();
428 let persist = persist
429 .open(PersistLocation::new_in_mem())
430 .await
431 .expect("location should be valid");
432 let shard_id = ShardId::new();
433
434 let mut cache0 = DurableCache::<TestCodec>::new(&persist, shard_id, "test1").await;
435 assert_none!(cache0.get_local(&"foo".into()));
436 assert_eq!(cache0.get(&"foo".into(), || "bar".into()).await, "bar");
437 assert_eq!(
438 cache0.entries_local().collect::<Vec<_>>(),
439 vec![(&"foo".into(), &"bar".into())]
440 );
441
442 cache0.set(&"k1".into(), Some(&"v1".into())).await;
443 cache0.set(&"k2".into(), Some(&"v2".into())).await;
444 assert_eq!(cache0.get_local(&"k1".into()), Some(&"v1".into()));
445 assert_eq!(cache0.get(&"k1".into(), || "v10".into()).await, &"v1");
446 assert_eq!(cache0.get_local(&"k2".into()), Some(&"v2".into()));
447 assert_eq!(cache0.get(&"k2".into(), || "v20".into()).await, &"v2");
448 assert_eq!(
449 cache0.entries_local().collect::<Vec<_>>(),
450 vec![
451 (&"foo".into(), &"bar".into()),
452 (&"k1".into(), &"v1".into()),
453 (&"k2".into(), &"v2".into())
454 ]
455 );
456
457 cache0.set(&"k1".into(), None).await;
458 assert_none!(cache0.get_local(&"k1".into()));
459 assert_eq!(
460 cache0.entries_local().collect::<Vec<_>>(),
461 vec![(&"foo".into(), &"bar".into()), (&"k2".into(), &"v2".into())]
462 );
463
464 cache0
465 .set_many(&[
466 (&"k1".into(), Some(&"v10".into())),
467 (&"k2".into(), None),
468 (&"k3".into(), None),
469 ])
470 .await;
471 assert_eq!(cache0.get_local(&"k1".into()), Some(&"v10".into()));
472 assert_none!(cache0.get_local(&"k2".into()));
473 assert_none!(cache0.get_local(&"k3".into()));
474 assert_eq!(
475 cache0.entries_local().collect::<Vec<_>>(),
476 vec![
477 (&"foo".into(), &"bar".into()),
478 (&"k1".into(), &"v10".into()),
479 ]
480 );
481
482 cache0
483 .set_many(&[
484 (&"k4".into(), Some(&"v40".into())),
485 (&"k4".into(), Some(&"v41".into())),
486 (&"k4".into(), Some(&"v42".into())),
487 (&"k5".into(), Some(&"v50".into())),
488 (&"k5".into(), Some(&"v51".into())),
489 (&"k5".into(), Some(&"v52".into())),
490 ])
491 .await;
492 assert_eq!(cache0.get_local(&"k4".into()), Some(&"v40".into()));
493 assert_eq!(cache0.get_local(&"k5".into()), Some(&"v50".into()));
494 assert_eq!(
495 cache0.entries_local().collect::<Vec<_>>(),
496 vec![
497 (&"foo".into(), &"bar".into()),
498 (&"k1".into(), &"v10".into()),
499 (&"k4".into(), &"v40".into()),
500 (&"k5".into(), &"v50".into()),
501 ]
502 );
503
504 let mut cache1 = DurableCache::<TestCodec>::new(&persist, shard_id, "test2").await;
505 assert_eq!(cache1.get(&"foo".into(), || panic!("boom")).await, "bar");
506 assert_eq!(cache1.get(&"k1".into(), || panic!("boom")).await, &"v10");
507 assert_none!(cache1.get_local(&"k2".into()));
508 assert_none!(cache1.get_local(&"k3".into()));
509 assert_eq!(
510 cache1.entries_local().collect::<Vec<_>>(),
511 vec![
512 (&"foo".into(), &"bar".into()),
513 (&"k1".into(), &"v10".into()),
514 (&"k4".into(), &"v40".into()),
515 (&"k5".into(), &"v50".into()),
516 ]
517 );
518
519 let fuel = || 131_072;
521 let wait = || Duration::from_millis(0);
522 cache1.dangerous_compact_shard(fuel, wait).await
523 }
524}