1use std::collections::BTreeMap;
13use std::sync::{Arc, Mutex};
14
15use anyhow::anyhow;
16use async_trait::async_trait;
17use bytes::Bytes;
18use futures_util::{StreamExt, stream};
19use mz_ore::bytes::SegmentedBytes;
20use mz_ore::cast::CastFrom;
21use mz_ore::future::yield_now;
22
23use crate::error::Error;
24use crate::location::{
25 Blob, BlobMetadata, CaSResult, Consensus, Determinate, ExternalError, ResultStream, SeqNo,
26 VersionedData,
27};
28
/// A registry of in-memory blob stores keyed by path.
///
/// Only compiled for tests. Every [`MemBlob`] handed out for the same path
/// shares a single underlying [`MemBlobCore`].
#[cfg(test)]
#[derive(Debug)]
pub struct MemMultiRegistry {
    /// The shared core backing each path that has been requested so far.
    blob_by_path: BTreeMap<String, Arc<tokio::sync::Mutex<MemBlobCore>>>,
    /// The tombstone setting copied into every core this registry creates.
    tombstone: bool,
}
37
#[cfg(test)]
impl MemMultiRegistry {
    /// Creates an empty registry.
    ///
    /// `tombstone` is forwarded to every [`MemBlobCore`] the registry later
    /// creates in [`MemMultiRegistry::blob`].
    pub fn new(tombstone: bool) -> Self {
        Self {
            tombstone,
            blob_by_path: BTreeMap::new(),
        }
    }

    /// Opens a [`MemBlob`] for `path`.
    ///
    /// The first request for a path creates a fresh, empty core and remembers
    /// it; later requests for the same path return a handle onto that same
    /// core.
    pub fn blob(&mut self, path: &str) -> MemBlob {
        // Copy the flag out first so the closure below does not need to borrow
        // `self` while the map is mutably borrowed.
        let tombstone = self.tombstone;
        let shared_core = self
            .blob_by_path
            .entry(path.to_string())
            .or_insert_with(|| {
                Arc::new(tokio::sync::Mutex::new(MemBlobCore {
                    dataz: Default::default(),
                    tombstone,
                }))
            });
        let config = MemBlobConfig {
            core: Arc::clone(shared_core),
        };
        MemBlob::open(config)
    }
}
68
/// The state behind a [`MemBlob`]: a sorted map from key to stored value.
#[derive(Debug, Default)]
struct MemBlobCore {
    /// Stored bytes per key, paired with an "exists" flag. An entry whose flag
    /// is `false` is hidden from `get` and from listing, but can still be
    /// brought back by `restore`.
    dataz: BTreeMap<String, (Bytes, bool)>,
    /// When `true`, `delete` clears the "exists" flag instead of removing the
    /// entry from the map.
    tombstone: bool,
}
74
75impl MemBlobCore {
76 fn get(&self, key: &str) -> Result<Option<Bytes>, ExternalError> {
77 Ok(self
78 .dataz
79 .get(key)
80 .and_then(|(x, exists)| exists.then(|| Bytes::clone(x))))
81 }
82
83 fn set(&mut self, key: &str, value: Bytes) -> Result<(), ExternalError> {
84 self.dataz.insert(key.to_owned(), (value, true));
85 Ok(())
86 }
87
88 fn list_keys_and_metadata(
89 &self,
90 key_prefix: &str,
91 f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
92 ) -> Result<(), ExternalError> {
93 for (key, (value, exists)) in &self.dataz {
94 if !*exists || !key.starts_with(key_prefix) {
95 continue;
96 }
97
98 f(BlobMetadata {
99 key,
100 size_in_bytes: u64::cast_from(value.len()),
101 });
102 }
103
104 Ok(())
105 }
106
107 fn delete(&mut self, key: &str) -> Result<Option<usize>, ExternalError> {
108 let bytes = if self.tombstone {
109 self.dataz.get_mut(key).and_then(|(x, exists)| {
110 let deleted_size = exists.then(|| x.len());
111 *exists = false;
112 deleted_size
113 })
114 } else {
115 self.dataz.remove(key).map(|(x, _)| x.len())
116 };
117 Ok(bytes)
118 }
119
120 fn restore(&mut self, key: &str) -> Result<(), ExternalError> {
121 match self.dataz.get_mut(key) {
122 None => Err(
123 Determinate::new(anyhow!("unable to restore {key} from in-memory state")).into(),
124 ),
125 Some((_, exists)) => {
126 *exists = true;
127 Ok(())
128 }
129 }
130 }
131}
132
/// Configuration for opening a [`MemBlob`].
///
/// The derived `Default` yields an empty core with tombstone mode off.
#[derive(Debug, Default)]
pub struct MemBlobConfig {
    /// The shared state that the opened [`MemBlob`] will operate on.
    core: Arc<tokio::sync::Mutex<MemBlobCore>>,
}
138
139impl MemBlobConfig {
140 pub fn new(tombstone: bool) -> Self {
142 Self {
143 core: Arc::new(tokio::sync::Mutex::new(MemBlobCore {
144 dataz: Default::default(),
145 tombstone,
146 })),
147 }
148 }
149}
150
/// An in-memory implementation of [`Blob`].
#[derive(Debug)]
pub struct MemBlob {
    /// The key/value state, behind an async mutex and shareable with other
    /// handles opened from the same core.
    core: Arc<tokio::sync::Mutex<MemBlobCore>>,
}
156
157impl MemBlob {
158 pub fn open(config: MemBlobConfig) -> Self {
160 MemBlob { core: config.core }
161 }
162}
163
164#[async_trait]
165impl Blob for MemBlob {
166 async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
167 let () = yield_now().await;
169 let maybe_bytes = self.core.lock().await.get(key)?;
170 Ok(maybe_bytes.map(SegmentedBytes::from))
171 }
172
173 async fn list_keys_and_metadata(
174 &self,
175 key_prefix: &str,
176 f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
177 ) -> Result<(), ExternalError> {
178 let () = yield_now().await;
180 self.core.lock().await.list_keys_and_metadata(key_prefix, f)
181 }
182
183 async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
184 let () = yield_now().await;
186 self.core.lock().await.set(key, value)
188 }
189
190 async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
191 let () = yield_now().await;
193 self.core.lock().await.delete(key)
194 }
195
196 async fn restore(&self, key: &str) -> Result<(), ExternalError> {
197 let () = yield_now().await;
199 self.core.lock().await.restore(key)
200 }
201}
202
/// An in-memory implementation of [`Consensus`].
#[derive(Debug)]
pub struct MemConsensus {
    /// The versions recorded for each key. `compare_and_set` only ever
    /// appends to a key's list, and `truncate` drops versions below a seqno.
    data: Arc<Mutex<BTreeMap<String, Vec<VersionedData>>>>,
}
210
211impl Default for MemConsensus {
212 fn default() -> Self {
213 Self {
214 data: Arc::new(Mutex::new(BTreeMap::new())),
215 }
216 }
217}
218
219impl MemConsensus {
220 fn scan_store(
221 store: &BTreeMap<String, Vec<VersionedData>>,
222 key: &str,
223 from: SeqNo,
224 limit: usize,
225 ) -> Result<Vec<VersionedData>, ExternalError> {
226 let results = if let Some(values) = store.get(key) {
227 let from_idx = values.partition_point(|x| x.seqno < from);
228 let from_values = &values[from_idx..];
229 let from_values = &from_values[..usize::min(limit, from_values.len())];
230 from_values.to_vec()
231 } else {
232 Vec::new()
233 };
234 Ok(results)
235 }
236}
237
238#[async_trait]
239impl Consensus for MemConsensus {
240 fn list_keys(&self) -> ResultStream<'_, String> {
241 let store = self.data.lock().expect("lock poisoned");
243 let keys: Vec<_> = store.keys().cloned().collect();
244 Box::pin(stream::iter(keys).map(Ok))
245 }
246
247 async fn head(&self, key: &str) -> Result<Option<VersionedData>, ExternalError> {
248 let () = yield_now().await;
250 let store = self.data.lock().map_err(Error::from)?;
251 let values = match store.get(key) {
252 None => return Ok(None),
253 Some(values) => values,
254 };
255
256 Ok(values.last().cloned())
257 }
258
259 async fn compare_and_set(
260 &self,
261 key: &str,
262 expected: Option<SeqNo>,
263 new: VersionedData,
264 ) -> Result<CaSResult, ExternalError> {
265 let () = yield_now().await;
267 if let Some(expected) = expected {
268 if new.seqno <= expected {
269 return Err(ExternalError::from(anyhow!(
270 "new seqno must be strictly greater than expected. Got new: {:?} expected: {:?}",
271 new.seqno,
272 expected
273 )));
274 }
275 }
276
277 if new.seqno.0 > i64::MAX.try_into().expect("i64::MAX known to fit in u64") {
278 return Err(ExternalError::from(anyhow!(
279 "sequence numbers must fit within [0, i64::MAX], received: {:?}",
280 new.seqno
281 )));
282 }
283 let mut store = self.data.lock().map_err(Error::from)?;
284
285 let data = match store.get(key) {
286 None => None,
287 Some(values) => values.last(),
288 };
289
290 let seqno = data.as_ref().map(|data| data.seqno);
291
292 if seqno != expected {
293 return Ok(CaSResult::ExpectationMismatch);
294 }
295
296 store.entry(key.to_string()).or_default().push(new);
297
298 Ok(CaSResult::Committed)
299 }
300
301 async fn scan(
302 &self,
303 key: &str,
304 from: SeqNo,
305 limit: usize,
306 ) -> Result<Vec<VersionedData>, ExternalError> {
307 let () = yield_now().await;
309 let store = self.data.lock().map_err(Error::from)?;
310 Self::scan_store(&store, key, from, limit)
311 }
312
313 async fn truncate(&self, key: &str, seqno: SeqNo) -> Result<usize, ExternalError> {
314 let () = yield_now().await;
316 let current = self.head(key).await?;
317 if current.map_or(true, |data| data.seqno < seqno) {
318 return Err(ExternalError::from(anyhow!(
319 "upper bound too high for truncate: {:?}",
320 seqno
321 )));
322 }
323
324 let mut store = self.data.lock().map_err(Error::from)?;
325
326 let mut deleted = 0;
327 if let Some(values) = store.get_mut(key) {
328 let count_before = values.len();
329 values.retain(|val| val.seqno >= seqno);
330 deleted += count_before - values.len();
331 }
332
333 Ok(deleted)
334 }
335}
336
#[cfg(test)]
mod tests {
    use crate::location::tests::{blob_impl_test, consensus_impl_test};

    use super::*;

    /// Runs the shared `Blob` conformance test against `MemBlob`, once with
    /// tombstone mode off and once with it on.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn mem_blob() -> Result<(), ExternalError> {
        for tombstone in [false, true] {
            // A fresh registry per mode; within one run, blobs opened for the
            // same path share state through it.
            let registry = Arc::new(tokio::sync::Mutex::new(MemMultiRegistry::new(tombstone)));
            blob_impl_test(move |path| {
                let path = path.to_owned();
                let registry = Arc::clone(&registry);
                async move { Ok(registry.lock().await.blob(&path)) }
            })
            .await?;
        }

        Ok(())
    }

    /// Runs the shared `Consensus` conformance test against `MemConsensus`.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn mem_consensus() -> Result<(), ExternalError> {
        consensus_impl_test(|| async { Ok(MemConsensus::default()) }).await
    }
}