1use std::sync::{Arc, Mutex};
13
14use async_trait::async_trait;
15use bytes::Bytes;
16use mz_dyncfg::{Config, ConfigSet};
17use mz_ore::bytes::SegmentedBytes;
18use mz_ore::cast::CastFrom;
19use mz_persist::location::{Blob, BlobMetadata, ExternalError};
20
21use crate::cfg::PersistConfig;
22use crate::internal::metrics::Metrics;
23
24#[derive(Debug)]
26pub struct BlobMemCache {
27 cfg: Arc<ConfigSet>,
29 num_workers: usize,
31 metrics: Arc<Metrics>,
32 cache: Mutex<lru::Lru<String, SegmentedBytes>>,
33 blob: Arc<dyn Blob>,
34}
35
36pub(crate) const BLOB_CACHE_MEM_LIMIT_BYTES: Config<usize> = Config::new(
37 "persist_blob_cache_mem_limit_bytes",
38 128 * 1024 * 1024,
40 "Capacity of in-mem blob cache in bytes (Materialize).",
41);
42
43pub(crate) const BLOB_CACHE_SCALE_WITH_THREADS: Config<bool> = Config::new(
44 "persist_blob_cache_scale_with_threads",
45 false,
46 "Whether or not the size of the in-mem blob cache scales with the number of threads in the current process (Materialize).",
47);
48
49pub(crate) const BLOB_CACHE_SCALE_FACTOR_BYTES: Config<usize> = Config::new(
50 "persist_blob_cache_scale_factor_bytes",
51 32 * 1024 * 1024,
53 "Scale factor for the in-mem blob cache, in bytes, if scaling with threads (Materialize).",
54);
55
56impl BlobMemCache {
57 pub fn new(cfg: &PersistConfig, metrics: Arc<Metrics>, blob: Arc<dyn Blob>) -> Arc<dyn Blob> {
58 let eviction_metrics = Arc::clone(&metrics);
59 let capacity_bytes =
60 BlobMemCache::get_capacity_bytes(&cfg.configs, cfg.isolated_runtime_worker_threads);
61 let cache = lru::Lru::new(capacity_bytes, move |_, _, _| {
62 eviction_metrics.blob_cache_mem.evictions.inc()
63 });
64 let blob = BlobMemCache {
65 cfg: Arc::clone(&cfg.configs),
66 num_workers: cfg.isolated_runtime_worker_threads,
67 metrics,
68 cache: Mutex::new(cache),
69 blob,
70 };
71 Arc::new(blob)
72 }
73
74 fn resize_and_update_size_metrics(&self, cache: &mut lru::Lru<String, SegmentedBytes>) {
75 let capacity_bytes = BlobMemCache::get_capacity_bytes(&self.cfg, self.num_workers);
76 cache.update_capacity(capacity_bytes);
77 self.metrics
78 .blob_cache_mem
79 .size_blobs
80 .set(u64::cast_from(cache.entry_count()));
81 self.metrics
82 .blob_cache_mem
83 .size_bytes
84 .set(u64::cast_from(cache.entry_weight()));
85 }
86
87 fn get_capacity_bytes(cfg: &Arc<ConfigSet>, num_workers: usize) -> usize {
88 let static_size = BLOB_CACHE_MEM_LIMIT_BYTES.get(cfg);
91
92 if BLOB_CACHE_SCALE_WITH_THREADS.get(cfg) {
93 let per_thread_const = BLOB_CACHE_SCALE_FACTOR_BYTES.get(cfg);
94 let dynamic_size = num_workers.saturating_mul(per_thread_const);
95 std::cmp::max(dynamic_size, static_size)
96 } else {
97 static_size
98 }
99 }
100}
101
102#[async_trait]
103impl Blob for BlobMemCache {
104 async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
105 if let Some((_, cached_value)) = self.cache.lock().expect("lock poisoned").get(key) {
113 self.metrics.blob_cache_mem.hits_blobs.inc();
114 self.metrics
115 .blob_cache_mem
116 .hits_bytes
117 .inc_by(u64::cast_from(cached_value.len()));
118 return Ok(Some(cached_value.clone()));
119 }
120
121 let res = self.blob.get(key).await?;
122 if let Some(blob) = res.as_ref() {
123 let mut cache = self.cache.lock().expect("lock poisoned");
127 if blob.len() <= cache.capacity() {
131 cache.insert(key.to_owned(), blob.clone(), blob.len());
132 self.resize_and_update_size_metrics(&mut cache);
133 }
134 }
135 Ok(res)
136 }
137
138 async fn list_keys_and_metadata(
139 &self,
140 key_prefix: &str,
141 f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
142 ) -> Result<(), ExternalError> {
143 self.blob.list_keys_and_metadata(key_prefix, f).await
144 }
145
146 async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
147 let () = self.blob.set(key, value.clone()).await?;
148 let weight = value.len();
149 let mut cache = self.cache.lock().expect("lock poisoned");
150 if weight <= cache.capacity() {
154 cache.insert(key.to_owned(), SegmentedBytes::from(value), weight);
155 self.resize_and_update_size_metrics(&mut cache);
156 }
157 Ok(())
158 }
159
160 async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
161 let res = self.blob.delete(key).await;
162 let mut cache = self.cache.lock().expect("lock poisoned");
163 cache.remove(key);
164 self.resize_and_update_size_metrics(&mut cache);
165 res
166 }
167
168 async fn restore(&self, key: &str) -> Result<(), ExternalError> {
169 self.blob.restore(key).await
170 }
171}
172
173mod lru {
174 use std::borrow::Borrow;
175 use std::collections::BTreeMap;
176 use std::hash::Hash;
177
178 use mz_ore::collections::HashMap;
179
180 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
181 pub struct Weight(usize);
182
183 #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
184 pub struct Time(usize);
185
186 pub struct Lru<K, V> {
191 evict_fn: Box<dyn Fn(K, V, usize) + Send>,
192 capacity: Weight,
193
194 next_time: Time,
195 entries: HashMap<K, (V, Weight, Time)>,
196 by_time: BTreeMap<Time, K>,
197 total_weight: Weight,
198 }
199
200 impl<K: Hash + Eq + Clone, V> Lru<K, V> {
201 pub fn new<F>(capacity: usize, evict_fn: F) -> Self
207 where
208 F: Fn(K, V, usize) + Send + 'static,
209 {
210 Lru {
211 evict_fn: Box::new(evict_fn),
212 capacity: Weight(capacity),
213 next_time: Time::default(),
214 entries: HashMap::new(),
215 by_time: BTreeMap::new(),
216 total_weight: Weight(0),
217 }
218 }
219
220 pub fn capacity(&self) -> usize {
222 self.capacity.0
223 }
224
225 pub fn entry_count(&self) -> usize {
227 debug_assert_eq!(self.entries.len(), self.by_time.len());
228 self.entries.len()
229 }
230
231 pub fn entry_weight(&self) -> usize {
233 self.total_weight.0
234 }
235
236 pub fn update_capacity(&mut self, capacity: usize) {
239 self.capacity = Weight(capacity);
240 self.resize();
241 assert!(self.total_weight <= self.capacity);
242
243 #[cfg(test)]
246 self.validate();
247 }
248
249 pub fn get<Q>(&mut self, key: &Q) -> Option<(&K, &V)>
252 where
253 K: Borrow<Q>,
254 Q: Hash + Eq + ?Sized,
255 {
256 {
257 let (key, val, weight) = self.remove(key)?;
258 self.insert_not_exists(key, val, Weight(weight));
259 }
260 let (key, (val, _, _)) = self
261 .entries
262 .get_key_value(key)
263 .expect("internal lru invariant violated");
264
265 #[cfg(test)]
268 self.validate();
269
270 Some((key, val))
271 }
272
273 pub fn insert(&mut self, key: K, val: V, weight: usize) {
279 let _ = self.remove(&key);
280 self.insert_not_exists(key, val, Weight(weight));
281
282 #[cfg(test)]
285 self.validate();
286 }
287
288 pub fn remove<Q>(&mut self, k: &Q) -> Option<(K, V, usize)>
292 where
293 K: Borrow<Q>,
294 Q: Hash + Eq + ?Sized,
295 {
296 let (_, _, time) = self.entries.get(k)?;
297 let (key, val, weight) = self.remove_exists(time.clone());
298
299 #[cfg(test)]
302 self.validate();
303
304 Some((key, val, weight.0))
305 }
306
307 #[allow(dead_code)]
310 pub(crate) fn iter(&self) -> impl Iterator<Item = (&K, &V, usize)> {
311 self.by_time.iter().rev().map(|(_, key)| {
312 let (val, _, weight) = self
313 .entries
314 .get(key)
315 .expect("internal lru invariant violated");
316 (key, val, weight.0)
317 })
318 }
319
320 fn insert_not_exists(&mut self, key: K, val: V, weight: Weight) {
321 let time = self.next_time.clone();
322 self.next_time.0 += 1;
323
324 self.total_weight.0 = self
325 .total_weight
326 .0
327 .checked_add(weight.0)
328 .expect("weight overflow");
329 assert!(
330 self.entries
331 .insert(key.clone(), (val, weight, time.clone()))
332 .is_none()
333 );
334 assert!(self.by_time.insert(time, key).is_none());
335 self.resize();
336 }
337
338 fn remove_exists(&mut self, time: Time) -> (K, V, Weight) {
339 let key = self
340 .by_time
341 .remove(&time)
342 .expect("internal list invariant violated");
343 let (val, weight, _time) = self
344 .entries
345 .remove(&key)
346 .expect("internal list invariant violated");
347 self.total_weight.0 = self
348 .total_weight
349 .0
350 .checked_sub(weight.0)
351 .expect("internal lru invariant violated");
352
353 (key, val, weight)
354 }
355
356 fn resize(&mut self) {
357 while self.total_weight > self.capacity {
358 let (time, _) = self
359 .by_time
360 .first_key_value()
361 .expect("internal lru invariant violated");
362 let (key, val, weight) = self.remove_exists(time.clone());
363 (self.evict_fn)(key, val, weight.0);
364 }
365 }
366
367 #[cfg(test)]
372 pub(crate) fn validate(&self) {
373 assert!(self.total_weight <= self.capacity);
374
375 let mut count = 0;
376 let mut weight = 0;
377 for (time, key) in self.by_time.iter() {
378 let (_val, w, t) = self
379 .entries
380 .get(key)
381 .expect("internal lru invariant violated");
382 count += 1;
383 weight += w.0;
384 assert_eq!(time, t);
385 }
386 assert_eq!(count, self.by_time.len());
387 assert_eq!(weight, self.total_weight.0);
388 }
389 }
390
391 impl<K: std::fmt::Debug, V> std::fmt::Debug for Lru<K, V> {
392 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 let Lru {
394 evict_fn: _,
395 capacity,
396 next_time,
397 entries: _,
398 by_time,
399 total_weight,
400 } = self;
401 f.debug_struct("Lru")
402 .field("capacity", &capacity)
403 .field("total_weight", &total_weight)
404 .field("next_time", &next_time)
405 .field("by_time", &by_time)
406 .finish_non_exhaustive()
407 }
408 }
409}
410
#[cfg(test)]
mod tests {
    use mz_ore::assert_none;
    use proptest::arbitrary::any;
    use proptest::proptest;
    use proptest_derive::Arbitrary;

    use super::lru::*;

    /// A single randomly generated operation to apply to the cache.
    #[derive(Debug, Arbitrary)]
    enum LruOp {
        Get { key: u8 },
        Insert { key: u8, weight: u8 },
        Remove { key: u8 },
    }

    /// Applies `ops` in order to a fresh cache, validating its invariants after every step.
    fn prop_testcase(ops: Vec<LruOp>) {
        let half_max = usize::from(u8::MAX / 2);
        let capacity = half_max * half_max / 2;
        let mut cache = Lru::new(capacity, |_, _, _| {});
        for op in ops {
            match op {
                LruOp::Get { key } => {
                    let _ = cache.get(&key);
                }
                LruOp::Insert { key, weight } => {
                    cache.insert(key, (), usize::from(weight));
                }
                LruOp::Remove { key } => {
                    let _ = cache.remove(&key);
                }
            }
            cache.validate();
        }
    }

    // Skipped when running under Miri.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn lru_cache_prop() {
        proptest!(|(ops in proptest::collection::vec(any::<LruOp>(), 0..100))| prop_testcase(ops));
    }

    impl Lru<&'static str, ()> {
        /// Returns the cached keys, most recently used first.
        fn keys(&self) -> Vec<&'static str> {
            self.iter().map(|(k, _, _)| *k).collect()
        }
    }

    /// Asserts the cache's entry count, total weight, and keys (most recently used first).
    #[track_caller]
    fn assert_state(
        cache: &Lru<&'static str, ()>,
        count: usize,
        weight: usize,
        keys: &[&'static str],
    ) {
        assert_eq!(cache.entry_count(), count);
        assert_eq!(cache.entry_weight(), weight);
        assert_eq!(cache.keys(), keys);
    }

    // Skipped when running under Miri.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn lru_cache_usage() {
        let mut cache = Lru::<&'static str, ()>::new(3, |_, _, _| {});

        // A new cache is empty.
        assert_eq!(cache.entry_count(), 0);
        assert_eq!(cache.entry_weight(), 0);

        // A single entry that fits.
        cache.insert("a", (), 2);
        assert_state(&cache, 1, 2, &["a"]);

        // A second entry that does not fit alongside the first evicts it.
        cache.insert("b", (), 2);
        assert_state(&cache, 1, 2, &["b"]);

        // An entry that exactly fills the remaining capacity.
        cache.insert("c", (), 1);
        assert_state(&cache, 2, 3, &["c", "b"]);

        // Two more small entries push out the oldest one.
        cache.insert("d", (), 1);
        cache.insert("e", (), 1);
        assert_state(&cache, 3, 3, &["e", "d", "c"]);

        // Reading the most recent entry leaves the order unchanged.
        cache.get("e");
        assert_state(&cache, 3, 3, &["e", "d", "c"]);

        // Reading the oldest entry moves it to the front.
        cache.get("c");
        assert_state(&cache, 3, 3, &["c", "e", "d"]);

        // Reading a middle entry moves it to the front.
        cache.get("e");
        assert_state(&cache, 3, 3, &["e", "c", "d"]);

        // Reading a missing key changes nothing.
        cache.get("f");
        assert_state(&cache, 3, 3, &["e", "c", "d"]);

        // Removing an existing entry.
        assert!(cache.remove("c").is_some());
        assert_state(&cache, 2, 2, &["e", "d"]);

        // Removing a missing key changes nothing.
        assert_none!(cache.remove("f"));
        assert_state(&cache, 2, 2, &["e", "d"]);

        // An entry as heavy as the whole capacity evicts everything else.
        cache.insert("f", (), 3);
        assert_state(&cache, 1, 3, &["f"]);

        // An entry heavier than the capacity evicts everything, including itself.
        cache.insert("g", (), 4);
        assert_eq!(cache.entry_count(), 0);
        assert_eq!(cache.entry_weight(), 0);

        // Growing the capacity makes room for an additional entry.
        cache.insert("h", (), 2);
        cache.insert("i", (), 1);
        cache.update_capacity(4);
        cache.insert("j", (), 1);
        assert_state(&cache, 3, 4, &["j", "i", "h"]);

        // Shrinking the capacity evicts the least recently used entry.
        cache.update_capacity(2);
        assert_state(&cache, 2, 2, &["j", "i"]);
    }
}