1use std::collections::BTreeMap;
13use std::sync::{Arc, Mutex};
14
15use anyhow::anyhow;
16use async_trait::async_trait;
17use bytes::Bytes;
18use futures_util::{StreamExt, stream};
19use mz_ore::bytes::SegmentedBytes;
20use mz_ore::cast::CastFrom;
21use mz_ore::future::yield_now;
22
23use crate::error::Error;
24use crate::location::{
25    Blob, BlobMetadata, CaSResult, Consensus, Determinate, ExternalError, ResultStream, SeqNo,
26    VersionedData,
27};
28
/// Test-only registry that hands out [MemBlob]s keyed by path, so that every
/// blob opened for the same path shares one in-memory store.
#[cfg(test)]
#[derive(Debug)]
pub struct MemMultiRegistry {
    blob_by_path: BTreeMap<String, Arc<tokio::sync::Mutex<MemBlobCore>>>,
    tombstone: bool,
}

#[cfg(test)]
impl MemMultiRegistry {
    /// Creates an empty registry.
    ///
    /// `tombstone` is copied into every store this registry creates and
    /// decides whether deletes in those stores are soft (restorable) or hard.
    pub fn new(tombstone: bool) -> Self {
        Self {
            tombstone,
            blob_by_path: BTreeMap::default(),
        }
    }

    /// Returns a [MemBlob] for `path`.
    ///
    /// The first call for a given `path` creates and remembers a fresh, empty
    /// store. Later calls with the same `path` return handles onto that same
    /// store, so writes through one handle are visible through the others.
    pub fn blob(&mut self, path: &str) -> MemBlob {
        let core = if let Some(existing) = self.blob_by_path.get(path) {
            Arc::clone(existing)
        } else {
            let fresh = Arc::new(tokio::sync::Mutex::new(MemBlobCore {
                dataz: Default::default(),
                tombstone: self.tombstone,
            }));
            self.blob_by_path.insert(path.to_string(), Arc::clone(&fresh));
            fresh
        };
        MemBlob::open(MemBlobConfig { core })
    }
}
68
69#[derive(Debug, Default)]
70struct MemBlobCore {
71    dataz: BTreeMap<String, (Bytes, bool)>,
72    tombstone: bool,
73}
74
75impl MemBlobCore {
76    fn get(&self, key: &str) -> Result<Option<Bytes>, ExternalError> {
77        Ok(self
78            .dataz
79            .get(key)
80            .and_then(|(x, exists)| exists.then(|| Bytes::clone(x))))
81    }
82
83    fn set(&mut self, key: &str, value: Bytes) -> Result<(), ExternalError> {
84        self.dataz.insert(key.to_owned(), (value, true));
85        Ok(())
86    }
87
88    fn list_keys_and_metadata(
89        &self,
90        key_prefix: &str,
91        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
92    ) -> Result<(), ExternalError> {
93        for (key, (value, exists)) in &self.dataz {
94            if !*exists || !key.starts_with(key_prefix) {
95                continue;
96            }
97
98            f(BlobMetadata {
99                key,
100                size_in_bytes: u64::cast_from(value.len()),
101            });
102        }
103
104        Ok(())
105    }
106
107    fn delete(&mut self, key: &str) -> Result<Option<usize>, ExternalError> {
108        let bytes = if self.tombstone {
109            self.dataz.get_mut(key).and_then(|(x, exists)| {
110                let deleted_size = exists.then(|| x.len());
111                *exists = false;
112                deleted_size
113            })
114        } else {
115            self.dataz.remove(key).map(|(x, _)| x.len())
116        };
117        Ok(bytes)
118    }
119
120    fn restore(&mut self, key: &str) -> Result<(), ExternalError> {
121        match self.dataz.get_mut(key) {
122            None => Err(
123                Determinate::new(anyhow!("unable to restore {key} from in-memory state")).into(),
124            ),
125            Some((_, exists)) => {
126                *exists = true;
127                Ok(())
128            }
129        }
130    }
131}
132
133#[derive(Debug, Default)]
135pub struct MemBlobConfig {
136    core: Arc<tokio::sync::Mutex<MemBlobCore>>,
137}
138
139impl MemBlobConfig {
140    pub fn new(tombstone: bool) -> Self {
142        Self {
143            core: Arc::new(tokio::sync::Mutex::new(MemBlobCore {
144                dataz: Default::default(),
145                tombstone,
146            })),
147        }
148    }
149}
150
151#[derive(Debug)]
153pub struct MemBlob {
154    core: Arc<tokio::sync::Mutex<MemBlobCore>>,
155}
156
157impl MemBlob {
158    pub fn open(config: MemBlobConfig) -> Self {
160        MemBlob { core: config.core }
161    }
162}
163
164#[async_trait]
165impl Blob for MemBlob {
166    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
167        let () = yield_now().await;
169        let maybe_bytes = self.core.lock().await.get(key)?;
170        Ok(maybe_bytes.map(SegmentedBytes::from))
171    }
172
173    async fn list_keys_and_metadata(
174        &self,
175        key_prefix: &str,
176        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
177    ) -> Result<(), ExternalError> {
178        let () = yield_now().await;
180        self.core.lock().await.list_keys_and_metadata(key_prefix, f)
181    }
182
183    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
184        let () = yield_now().await;
186        self.core.lock().await.set(key, value)
188    }
189
190    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
191        let () = yield_now().await;
193        self.core.lock().await.delete(key)
194    }
195
196    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
197        let () = yield_now().await;
199        self.core.lock().await.restore(key)
200    }
201}
202
203#[derive(Debug)]
205pub struct MemConsensus {
206    data: Arc<Mutex<BTreeMap<String, Vec<VersionedData>>>>,
209}
210
211impl Default for MemConsensus {
212    fn default() -> Self {
213        Self {
214            data: Arc::new(Mutex::new(BTreeMap::new())),
215        }
216    }
217}
218
219impl MemConsensus {
220    fn scan_store(
221        store: &BTreeMap<String, Vec<VersionedData>>,
222        key: &str,
223        from: SeqNo,
224        limit: usize,
225    ) -> Result<Vec<VersionedData>, ExternalError> {
226        let results = if let Some(values) = store.get(key) {
227            let from_idx = values.partition_point(|x| x.seqno < from);
228            let from_values = &values[from_idx..];
229            let from_values = &from_values[..usize::min(limit, from_values.len())];
230            from_values.to_vec()
231        } else {
232            Vec::new()
233        };
234        Ok(results)
235    }
236}
237
238#[async_trait]
239impl Consensus for MemConsensus {
240    fn list_keys(&self) -> ResultStream<'_, String> {
241        let store = self.data.lock().expect("lock poisoned");
243        let keys: Vec<_> = store.keys().cloned().collect();
244        Box::pin(stream::iter(keys).map(Ok))
245    }
246
247    async fn head(&self, key: &str) -> Result<Option<VersionedData>, ExternalError> {
248        let () = yield_now().await;
250        let store = self.data.lock().map_err(Error::from)?;
251        let values = match store.get(key) {
252            None => return Ok(None),
253            Some(values) => values,
254        };
255
256        Ok(values.last().cloned())
257    }
258
259    async fn compare_and_set(
260        &self,
261        key: &str,
262        expected: Option<SeqNo>,
263        new: VersionedData,
264    ) -> Result<CaSResult, ExternalError> {
265        let () = yield_now().await;
267        if let Some(expected) = expected {
268            if new.seqno <= expected {
269                return Err(ExternalError::from(anyhow!(
270                    "new seqno must be strictly greater than expected. Got new: {:?} expected: {:?}",
271                    new.seqno,
272                    expected
273                )));
274            }
275        }
276
277        if new.seqno.0 > i64::MAX.try_into().expect("i64::MAX known to fit in u64") {
278            return Err(ExternalError::from(anyhow!(
279                "sequence numbers must fit within [0, i64::MAX], received: {:?}",
280                new.seqno
281            )));
282        }
283        let mut store = self.data.lock().map_err(Error::from)?;
284
285        let data = match store.get(key) {
286            None => None,
287            Some(values) => values.last(),
288        };
289
290        let seqno = data.as_ref().map(|data| data.seqno);
291
292        if seqno != expected {
293            return Ok(CaSResult::ExpectationMismatch);
294        }
295
296        store.entry(key.to_string()).or_default().push(new);
297
298        Ok(CaSResult::Committed)
299    }
300
301    async fn scan(
302        &self,
303        key: &str,
304        from: SeqNo,
305        limit: usize,
306    ) -> Result<Vec<VersionedData>, ExternalError> {
307        let () = yield_now().await;
309        let store = self.data.lock().map_err(Error::from)?;
310        Self::scan_store(&store, key, from, limit)
311    }
312
313    async fn truncate(&self, key: &str, seqno: SeqNo) -> Result<usize, ExternalError> {
314        let () = yield_now().await;
316        let current = self.head(key).await?;
317        if current.map_or(true, |data| data.seqno < seqno) {
318            return Err(ExternalError::from(anyhow!(
319                "upper bound too high for truncate: {:?}",
320                seqno
321            )));
322        }
323
324        let mut store = self.data.lock().map_err(Error::from)?;
325
326        let mut deleted = 0;
327        if let Some(values) = store.get_mut(key) {
328            let count_before = values.len();
329            values.retain(|val| val.seqno >= seqno);
330            deleted += count_before - values.len();
331        }
332
333        Ok(deleted)
334    }
335}
336
#[cfg(test)]
mod tests {
    use crate::location::tests::{blob_impl_test, consensus_impl_test};

    use super::*;

    /// Runs `blob_impl_test` against `MemBlob`, once with hard deletes and
    /// once with tombstoned (restorable) deletes.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn mem_blob() -> Result<(), ExternalError> {
        for tombstone in [false, true] {
            let registry = Arc::new(tokio::sync::Mutex::new(MemMultiRegistry::new(tombstone)));
            blob_impl_test(move |path| {
                let path = path.to_owned();
                let registry = Arc::clone(&registry);
                async move { Ok(registry.lock().await.blob(&path)) }
            })
            .await?;
        }

        Ok(())
    }

    /// Runs `consensus_impl_test` against a fresh `MemConsensus` per request.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn mem_consensus() -> Result<(), ExternalError> {
        consensus_impl_test(|| async { Ok(MemConsensus::default()) }).await
    }
}