use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{StreamExt, stream};
use mz_ore::bytes::SegmentedBytes;
use mz_ore::cast::CastFrom;

use crate::error::Error;
use crate::location::{
    Blob, BlobMetadata, CaSResult, Consensus, Determinate, ExternalError, ResultStream, SeqNo,
    VersionedData,
};

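/// Yields control back to the async runtime once, by returning `Poll::Pending`
/// on the first poll and immediately re-waking the task. The in-memory `Blob`
/// and `Consensus` impls below call this so their methods still contain a real
/// await point, roughly mimicking implementations that do actual I/O.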
async fn yield_now() {
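    // A minimal hand-rolled yield future: the first poll wakes the task and
    // returns Pending, the second poll returns Ready.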
    struct YieldNow {
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                return Poll::Ready(());
            }

            self.yielded = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldNow { yielded: false }.await
}

#[cfg(test)]
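/// A test-only registry of `MemBlob` state keyed by path: opening the same
/// path twice returns handles that share the same underlying in-memory state.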
#[derive(Debug)]
pub struct MemMultiRegistry {
    blob_by_path: BTreeMap<String, Arc<tokio::sync::Mutex<MemBlobCore>>>,
    tombstone: bool,
}

#[cfg(test)]
impl MemMultiRegistry {
    pub fn new(tombstone: bool) -> Self {
        MemMultiRegistry {
            blob_by_path: BTreeMap::new(),
            tombstone,
        }
    }

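    /// Opens a `MemBlob` at the given path, creating the backing state on
    /// first use and sharing it with any blob previously opened at the same
    /// path.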
    pub fn blob(&mut self, path: &str) -> MemBlob {
        if let Some(blob) = self.blob_by_path.get(path) {
            MemBlob::open(MemBlobConfig {
                core: Arc::clone(blob),
            })
        } else {
            let blob = Arc::new(tokio::sync::Mutex::new(MemBlobCore {
                dataz: Default::default(),
                tombstone: self.tombstone,
            }));
            self.blob_by_path
                .insert(path.to_string(), Arc::clone(&blob));
            MemBlob::open(MemBlobConfig { core: blob })
        }
    }
}

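/// The in-memory state backing a `MemBlob`. When `tombstone` is true, deletes
/// only mark an entry as absent (keeping its bytes around) so that a later
/// `restore` can bring it back; otherwise deletes remove the entry entirely.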
#[derive(Debug, Default)]
struct MemBlobCore {
    dataz: BTreeMap<String, (Bytes, bool)>,
    tombstone: bool,
}

impl MemBlobCore {
    fn get(&self, key: &str) -> Result<Option<Bytes>, ExternalError> {
        Ok(self
            .dataz
            .get(key)
            .and_then(|(x, exists)| exists.then(|| Bytes::clone(x))))
    }

    fn set(&mut self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        self.dataz.insert(key.to_owned(), (value, true));
        Ok(())
    }

    fn list_keys_and_metadata(
        &self,
        key_prefix: &str,
        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
    ) -> Result<(), ExternalError> {
        for (key, (value, exists)) in &self.dataz {
            if !*exists || !key.starts_with(key_prefix) {
                continue;
            }

            f(BlobMetadata {
                key,
                size_in_bytes: u64::cast_from(value.len()),
            });
        }

        Ok(())
    }

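    /// Deletes `key`, returning the size of the deleted blob (or `None` if it
    /// did not exist). With tombstones enabled the entry is kept but marked as
    /// deleted.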
    fn delete(&mut self, key: &str) -> Result<Option<usize>, ExternalError> {
        let bytes = if self.tombstone {
            self.dataz.get_mut(key).and_then(|(x, exists)| {
                let deleted_size = exists.then(|| x.len());
                *exists = false;
                deleted_size
            })
        } else {
            self.dataz.remove(key).map(|(x, _)| x.len())
        };
        Ok(bytes)
    }

    fn restore(&mut self, key: &str) -> Result<(), ExternalError> {
        match self.dataz.get_mut(key) {
            None => Err(
                Determinate::new(anyhow!("unable to restore {key} from in-memory state")).into(),
            ),
            Some((_, exists)) => {
                *exists = true;
                Ok(())
            }
        }
    }
}

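/// Configuration for opening a `MemBlob`.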
#[derive(Debug, Default)]
pub struct MemBlobConfig {
    core: Arc<tokio::sync::Mutex<MemBlobCore>>,
}

impl MemBlobConfig {
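    /// Returns a config for a fresh, empty `MemBlob`. With `tombstone` set,
    /// deleted blobs are retained as tombstones and can later be restored.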
    pub fn new(tombstone: bool) -> Self {
        Self {
            core: Arc::new(tokio::sync::Mutex::new(MemBlobCore {
                dataz: Default::default(),
                tombstone,
            })),
        }
    }
}

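/// An in-memory implementation of [Blob].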
#[derive(Debug)]
pub struct MemBlob {
    core: Arc<tokio::sync::Mutex<MemBlobCore>>,
}

impl MemBlob {
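    /// Opens the blob described by the given config.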
    pub fn open(config: MemBlobConfig) -> Self {
        MemBlob { core: config.core }
    }
}

#[async_trait]
impl Blob for MemBlob {
    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
        let () = yield_now().await;
        let maybe_bytes = self.core.lock().await.get(key)?;
        Ok(maybe_bytes.map(SegmentedBytes::from))
    }

    async fn list_keys_and_metadata(
        &self,
        key_prefix: &str,
        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
    ) -> Result<(), ExternalError> {
        let () = yield_now().await;
        self.core.lock().await.list_keys_and_metadata(key_prefix, f)
    }

    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        let () = yield_now().await;
        self.core.lock().await.set(key, value)
    }

    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
        let () = yield_now().await;
        self.core.lock().await.delete(key)
    }

    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
        let () = yield_now().await;
        self.core.lock().await.restore(key)
    }
}

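/// An in-memory implementation of [Consensus].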
#[derive(Debug)]
pub struct MemConsensus {
    data: Arc<Mutex<BTreeMap<String, Vec<VersionedData>>>>,
}

impl Default for MemConsensus {
    fn default() -> Self {
        Self {
            data: Arc::new(Mutex::new(BTreeMap::new())),
        }
    }
}

impl MemConsensus {
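    /// Returns up to `limit` versions for `key`, in seqno order, starting at
    /// `from` (inclusive).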
    fn scan_store(
        store: &BTreeMap<String, Vec<VersionedData>>,
        key: &str,
        from: SeqNo,
        limit: usize,
    ) -> Result<Vec<VersionedData>, ExternalError> {
        let results = if let Some(values) = store.get(key) {
            let from_idx = values.partition_point(|x| x.seqno < from);
            let from_values = &values[from_idx..];
            let from_values = &from_values[..usize::min(limit, from_values.len())];
            from_values.to_vec()
        } else {
            Vec::new()
        };
        Ok(results)
    }
}

#[async_trait]
impl Consensus for MemConsensus {
    fn list_keys(&self) -> ResultStream<String> {
        let store = self.data.lock().expect("lock poisoned");
        let keys: Vec<_> = store.keys().cloned().collect();
        Box::pin(stream::iter(keys).map(Ok))
    }

    async fn head(&self, key: &str) -> Result<Option<VersionedData>, ExternalError> {
        let () = yield_now().await;
        let store = self.data.lock().map_err(Error::from)?;
        let values = match store.get(key) {
            None => return Ok(None),
            Some(values) => values,
        };

        Ok(values.last().cloned())
    }

    async fn compare_and_set(
        &self,
        key: &str,
        expected: Option<SeqNo>,
        new: VersionedData,
    ) -> Result<CaSResult, ExternalError> {
        let () = yield_now().await;
        if let Some(expected) = expected {
            if new.seqno <= expected {
                return Err(ExternalError::from(anyhow!(
                    "new seqno must be strictly greater than expected. Got new: {:?} expected: {:?}",
                    new.seqno,
                    expected
                )));
            }
        }

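        // Reject seqnos that don't fit in an i64, presumably to keep the
        // accepted range consistent across Consensus implementations.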
        if new.seqno.0 > i64::MAX.try_into().expect("i64::MAX known to fit in u64") {
            return Err(ExternalError::from(anyhow!(
                "sequence numbers must fit within [0, i64::MAX], received: {:?}",
                new.seqno
            )));
        }
        let mut store = self.data.lock().map_err(Error::from)?;

        let data = match store.get(key) {
            None => None,
            Some(values) => values.last(),
        };

        let seqno = data.as_ref().map(|data| data.seqno);

        if seqno != expected {
            return Ok(CaSResult::ExpectationMismatch);
        }

        store.entry(key.to_string()).or_default().push(new);

        Ok(CaSResult::Committed)
    }

    async fn scan(
        &self,
        key: &str,
        from: SeqNo,
        limit: usize,
    ) -> Result<Vec<VersionedData>, ExternalError> {
        let () = yield_now().await;
        let store = self.data.lock().map_err(Error::from)?;
        Self::scan_store(&store, key, from, limit)
    }

    async fn truncate(&self, key: &str, seqno: SeqNo) -> Result<usize, ExternalError> {
        let () = yield_now().await;
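        // Refuse to truncate past the current head (or a key with no data at
        // all), since that would delete the most recent version.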
        let current = self.head(key).await?;
        if current.map_or(true, |data| data.seqno < seqno) {
            return Err(ExternalError::from(anyhow!(
                "upper bound too high for truncate: {:?}",
                seqno
            )));
        }

        let mut store = self.data.lock().map_err(Error::from)?;

        let mut deleted = 0;
        if let Some(values) = store.get_mut(key) {
            let count_before = values.len();
            values.retain(|val| val.seqno >= seqno);
            deleted += count_before - values.len();
        }

        Ok(deleted)
    }
}

#[cfg(test)]
mod tests {
    use crate::location::tests::{blob_impl_test, consensus_impl_test};

    use super::*;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn mem_blob() -> Result<(), ExternalError> {
        let registry = Arc::new(tokio::sync::Mutex::new(MemMultiRegistry::new(false)));
        blob_impl_test(move |path| {
            let path = path.to_owned();
            let registry = Arc::clone(&registry);
            async move { Ok(registry.lock().await.blob(&path)) }
        })
        .await?;

        let registry = Arc::new(tokio::sync::Mutex::new(MemMultiRegistry::new(true)));
        blob_impl_test(move |path| {
            let path = path.to_owned();
            let registry = Arc::clone(&registry);
            async move { Ok(registry.lock().await.blob(&path)) }
        })
        .await?;

        Ok(())
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn mem_consensus() -> Result<(), ExternalError> {
        consensus_impl_test(|| async { Ok(MemConsensus::default()) }).await
    }
}