moka/sync/value_initializer.rs

1use parking_lot::RwLock;
2use std::{
3    any::{Any, TypeId},
4    fmt,
5    hash::{BuildHasher, Hash},
6    sync::Arc,
7};
8
9use crate::{
10    common::concurrent::arc::MiniArc,
11    ops::compute::{CompResult, Op},
12    Entry,
13};
14
15use super::{Cache, ComputeNone, OptionallyNone};
16
/// Number of segments of the concurrent hash map that holds the waiters.
const WAITER_MAP_NUM_SEGMENTS: usize = 64;

/// Type-erased error produced by an `init` closure. Readers of a waiter
/// downcast it back into its concrete error type.
type ErrorObject = Arc<dyn Any + Send + Sync + 'static>;

/// State of a waiter, i.e. of one in-flight initialization or computation of a
/// cache entry. The owner of the waiter holds its write lock while it works, so
/// other threads can only observe the state after the owner is done.
enum WaiterValue<V> {
    /// The owner has not published a result yet.
    Computing,
    /// The owner finished and produced either a value or an error.
    Ready(Result<V, ErrorObject>),
    /// The owner finished without publishing a value (set by `try_compute`).
    ReadyNone,
    /// The owner's closure panicked; readers retry from the beginning.
    /// See <https://github.com/moka-rs/moka/issues/43>.
    InitClosurePanicked,
}

impl<V> fmt::Debug for WaiterValue<V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `V` is not required to implement `Debug`, so only the variant name is
        // printed. The name must match the variant so that panic messages that
        // include `{:?}` of a state are not misleading.
        let label = match self {
            WaiterValue::Computing => "Computing",
            WaiterValue::Ready(_) => "Ready",
            WaiterValue::ReadyNone => "ReadyNone",
            WaiterValue::InitClosurePanicked => "InitClosurePanicked",
        };
        f.write_str(label)
    }
}
40
41type Waiter<V> = MiniArc<RwLock<WaiterValue<V>>>;
42
43pub(crate) enum InitResult<V, E> {
44    Initialized(V),
45    ReadExisting(V),
46    InitErr(Arc<E>),
47}
48
49pub(crate) struct ValueInitializer<K, V, S> {
50    // TypeId is the type ID of the concrete error type of generic type E in the
51    // try_get_with method. We use the type ID as a part of the key to ensure that
52    // we can always downcast the trait object ErrorObject (in Waiter<V>) into
53    // its concrete type.
54    waiters: crate::cht::SegmentedHashMap<(Arc<K>, TypeId), Waiter<V>, S>,
55}
56
57impl<K, V, S> ValueInitializer<K, V, S>
58where
59    K: Hash + Eq + Send + Sync + 'static,
60    V: Clone + Send + Sync + 'static,
61    S: BuildHasher + Clone + Send + Sync + 'static,
62{
63    pub(crate) fn with_hasher(hasher: S) -> Self {
64        Self {
65            waiters: crate::cht::SegmentedHashMap::with_num_segments_and_hasher(
66                WAITER_MAP_NUM_SEGMENTS,
67                hasher,
68            ),
69        }
70    }
71
72    /// # Panics
73    /// Panics if the `init` closure has been panicked.
74    pub(crate) fn try_init_or_read<O, E>(
75        &self,
76        key: &Arc<K>,
77        type_id: TypeId,
78        // Closure to get an existing value from cache.
79        mut get: impl FnMut() -> Option<V>,
80        // Closure to initialize a new value.
81        init: impl FnOnce() -> O,
82        // Closure to insert a new value into cache.
83        mut insert: impl FnMut(V),
84        // Function to convert a value O, returned from the init future, into
85        // Result<V, E>.
86        post_init: fn(O) -> Result<V, E>,
87    ) -> InitResult<V, E>
88    where
89        E: Send + Sync + 'static,
90    {
91        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
92        use InitResult::{InitErr, ReadExisting};
93
94        const MAX_RETRIES: usize = 200;
95        let mut retries = 0;
96
97        let (w_key, w_hash) = self.waiter_key_hash(key, type_id);
98
99        let waiter = MiniArc::new(RwLock::new(WaiterValue::Computing));
100        let mut lock = waiter.write();
101
102        loop {
103            let Some(existing_waiter) = self.try_insert_waiter(w_key.clone(), w_hash, &waiter)
104            else {
105                // Inserted.
106                break;
107            };
108
109            // Somebody else's waiter already exists, so wait for its result to become available.
110            let waiter_result = existing_waiter.read();
111            match &*waiter_result {
112                WaiterValue::Ready(Ok(value)) => return ReadExisting(value.clone()),
113                WaiterValue::Ready(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()),
114                // Somebody else's init closure has been panicked.
115                WaiterValue::InitClosurePanicked => {
116                    retries += 1;
117                    assert!(
118                        retries < MAX_RETRIES,
119                        "Too many retries. Tried to read the return value from the `init` \
120                        closure but failed {retries} times. Maybe the `init` kept panicking?"
121                    );
122
123                    // Retry from the beginning.
124                    continue;
125                }
126                // Unexpected state.
127                s @ (WaiterValue::Computing | WaiterValue::ReadyNone) => panic!(
128                    "Got unexpected state `{s:?}` after resolving `init` future. \
129                    This might be a bug in Moka"
130                ),
131            }
132        }
133
134        // Our waiter was inserted.
135
136        // Check if the value has already been inserted by other thread.
137        if let Some(value) = get() {
138            // Yes. Set the waiter value, remove our waiter, and return
139            // the existing value.
140            *lock = WaiterValue::Ready(Ok(value.clone()));
141            self.remove_waiter(w_key, w_hash);
142            return InitResult::ReadExisting(value);
143        }
144
145        // The value still does note exist. Let's evaluate the init
146        // closure. Catching panic is safe here as we do not try to
147        // evaluate the closure again.
148        match catch_unwind(AssertUnwindSafe(init)) {
149            // Evaluated.
150            Ok(value) => {
151                let init_res = match post_init(value) {
152                    Ok(value) => {
153                        insert(value.clone());
154                        *lock = WaiterValue::Ready(Ok(value.clone()));
155                        InitResult::Initialized(value)
156                    }
157                    Err(e) => {
158                        let err: ErrorObject = Arc::new(e);
159                        *lock = WaiterValue::Ready(Err(Arc::clone(&err)));
160                        InitResult::InitErr(err.downcast().unwrap())
161                    }
162                };
163                self.remove_waiter(w_key, w_hash);
164                init_res
165            }
166            // Panicked.
167            Err(payload) => {
168                *lock = WaiterValue::InitClosurePanicked;
169                // Remove the waiter so that others can retry.
170                self.remove_waiter(w_key, w_hash);
171                resume_unwind(payload);
172            }
173        }
174        // The write lock will be unlocked here.
175    }
176
177    /// # Panics
178    /// Panics if the `init` closure has been panicked.
179    pub(crate) fn try_compute<F, O, E>(
180        &self,
181        c_key: Arc<K>,
182        c_hash: u64,
183        cache: &Cache<K, V, S>,
184        f: F,
185        post_init: fn(O) -> Result<Op<V>, E>,
186        allow_nop: bool,
187    ) -> Result<CompResult<K, V>, E>
188    where
189        V: 'static,
190        F: FnOnce(Option<Entry<K, V>>) -> O,
191        E: Send + Sync + 'static,
192    {
193        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
194
195        let type_id = TypeId::of::<ComputeNone>();
196        let (w_key, w_hash) = self.waiter_key_hash(&c_key, type_id);
197        let waiter = MiniArc::new(RwLock::new(WaiterValue::Computing));
198        // NOTE: We have to acquire a write lock before `try_insert_waiter`,
199        // so that any concurrent attempt will get our lock and wait on it.
200        let mut lock = waiter.write();
201
202        loop {
203            let Some(existing_waiter) = self.try_insert_waiter(w_key.clone(), w_hash, &waiter)
204            else {
205                // Inserted.
206                break;
207            };
208
209            // Somebody else's waiter already exists, so wait for it to finish
210            // (wait for it to release the write lock).
211            let waiter_result = existing_waiter.read();
212            match &*waiter_result {
213                // Unexpected state.
214                WaiterValue::Computing => panic!(
215                    "Got unexpected state `Computing` after resolving `init` future. \
216                    This might be a bug in Moka"
217                ),
218                _ => {
219                    // Try to insert our waiter again.
220                    continue;
221                }
222            }
223        }
224
225        // Our waiter was inserted.
226
227        // Get the current value.
228        let ignore_if = None as Option<&mut fn(&V) -> bool>;
229        let maybe_entry = cache
230            .base
231            .get_with_hash_and_ignore_if(&c_key, c_hash, ignore_if, true);
232        let maybe_value = if allow_nop {
233            maybe_entry.as_ref().map(|ent| ent.value().clone())
234        } else {
235            None
236        };
237        let entry_existed = maybe_entry.is_some();
238
239        // Evaluate the `f` closure. Catching panic is safe here as we will not
240        // evaluate the closure again.
241        let output = match catch_unwind(AssertUnwindSafe(|| f(maybe_entry))) {
242            // Evaluated.
243            Ok(output) => {
244                *lock = WaiterValue::ReadyNone;
245                output
246            }
247            // Panicked.
248            Err(payload) => {
249                *lock = WaiterValue::InitClosurePanicked;
250                // Remove the waiter so that others can retry.
251                self.remove_waiter(w_key, w_hash);
252                resume_unwind(payload);
253            }
254        };
255
256        let op = match post_init(output) {
257            Ok(op) => op,
258            Err(e) => {
259                self.remove_waiter(w_key, w_hash);
260                return Err(e);
261            }
262        };
263
264        let result = match op {
265            Op::Nop => {
266                if let Some(value) = maybe_value {
267                    Ok(CompResult::Unchanged(Entry::new(
268                        Some(c_key),
269                        value,
270                        false,
271                        false,
272                    )))
273                } else {
274                    Ok(CompResult::StillNone(c_key))
275                }
276            }
277            Op::Put(value) => {
278                cache.insert_with_hash(Arc::clone(&c_key), c_hash, value.clone());
279                if entry_existed {
280                    crossbeam_epoch::pin().flush();
281                    let entry = Entry::new(Some(c_key), value, true, true);
282                    Ok(CompResult::ReplacedWith(entry))
283                } else {
284                    let entry = Entry::new(Some(c_key), value, true, false);
285                    Ok(CompResult::Inserted(entry))
286                }
287            }
288            Op::Remove => {
289                let maybe_prev_v = cache.invalidate_with_hash(&c_key, c_hash, true);
290                if let Some(prev_v) = maybe_prev_v {
291                    crossbeam_epoch::pin().flush();
292                    let entry = Entry::new(Some(c_key), prev_v, false, false);
293                    Ok(CompResult::Removed(entry))
294                } else {
295                    Ok(CompResult::StillNone(c_key))
296                }
297            }
298        };
299        self.remove_waiter(w_key, w_hash);
300        result
301
302        // The lock will be unlocked here.
303    }
304
305    /// The `post_init` function for the `get_with` method of cache.
306    pub(crate) fn post_init_for_get_with(value: V) -> Result<V, ()> {
307        Ok(value)
308    }
309
310    /// The `post_init` function for the `optionally_get_with` method of cache.
311    pub(crate) fn post_init_for_optionally_get_with(
312        value: Option<V>,
313    ) -> Result<V, Arc<OptionallyNone>> {
314        // `value` can be either `Some` or `None`. For `None` case, without change
315        // the existing API too much, we will need to convert `None` to Arc<E> here.
316        // `Infallible` could not be instantiated. So it might be good to use an
317        // empty struct to indicate the error type.
318        value.ok_or(Arc::new(OptionallyNone))
319    }
320
321    /// The `post_init` function for `try_get_with` method of cache.
322    pub(crate) fn post_init_for_try_get_with<E>(result: Result<V, E>) -> Result<V, E> {
323        result
324    }
325
326    /// The `post_init` function for the `and_upsert_with` method of cache.
327    pub(crate) fn post_init_for_upsert_with(value: V) -> Result<Op<V>, ()> {
328        Ok(Op::Put(value))
329    }
330
331    /// The `post_init` function for the `and_compute_with` method of cache.
332    pub(crate) fn post_init_for_compute_with(op: Op<V>) -> Result<Op<V>, ()> {
333        Ok(op)
334    }
335
336    /// The `post_init` function for the `and_try_compute_with` method of cache.
337    pub(crate) fn post_init_for_try_compute_with<E>(op: Result<Op<V>, E>) -> Result<Op<V>, E>
338    where
339        E: Send + Sync + 'static,
340    {
341        op
342    }
343
344    /// Returns the `type_id` for `get_with` method of cache.
345    pub(crate) fn type_id_for_get_with() -> TypeId {
346        // NOTE: We use a regular function here instead of a const fn because TypeId
347        // is not stable as a const fn. (as of our MSRV)
348        TypeId::of::<()>()
349    }
350
351    /// Returns the `type_id` for `optionally_get_with` method of cache.
352    pub(crate) fn type_id_for_optionally_get_with() -> TypeId {
353        TypeId::of::<OptionallyNone>()
354    }
355
356    /// Returns the `type_id` for `try_get_with` method of cache.
357    pub(crate) fn type_id_for_try_get_with<E: 'static>() -> TypeId {
358        TypeId::of::<E>()
359    }
360
361    #[inline]
362    fn remove_waiter(&self, w_key: (Arc<K>, TypeId), w_hash: u64) {
363        self.waiters.remove(w_hash, |k| k == &w_key);
364    }
365
366    #[inline]
367    fn try_insert_waiter(
368        &self,
369        w_key: (Arc<K>, TypeId),
370        w_hash: u64,
371        waiter: &Waiter<V>,
372    ) -> Option<Waiter<V>> {
373        let waiter = MiniArc::clone(waiter);
374        self.waiters.insert_if_not_present(w_key, w_hash, waiter)
375    }
376
377    #[inline]
378    fn waiter_key_hash(&self, c_key: &Arc<K>, type_id: TypeId) -> ((Arc<K>, TypeId), u64) {
379        let w_key = (Arc::clone(c_key), type_id);
380        let w_hash = self.waiters.hash(&w_key);
381        (w_key, w_hash)
382    }
383}
384
#[cfg(test)]
impl<K, V, S> ValueInitializer<K, V, S> {
    /// Test-only accessor: the number of waiters currently held in the
    /// waiter map.
    pub(crate) fn waiter_count(&self) -> usize {
        self.waiters.len()
    }
}