launchdarkly_server_sdk/stores/persistent_store_wrapper.rs

1use std::collections::HashMap;
2use std::convert::{TryFrom, TryInto};
3use std::iter::FromIterator;
4use std::time::Duration;
5
6use launchdarkly_server_sdk_evaluation::{Flag, Segment, Store};
7
8use super::persistent_store::PersistentDataStore;
9use super::persistent_store_cache::CachePair;
10use super::store::{DataStore, UpdateError};
11use super::store_types::{
12    AllData, DataKind, PatchTarget, SerializeToSerializedItem, SerializedItem, StorageItem,
13};
14
15trait WithKind {
16    const KIND: DataKind;
17}
18
19impl WithKind for StorageItem<Flag> {
20    const KIND: DataKind = DataKind::Flag;
21}
22
23impl WithKind for StorageItem<Segment> {
24    const KIND: DataKind = DataKind::Segment;
25}
26
27pub(super) struct PersistentDataStoreWrapper {
28    store: Box<dyn PersistentDataStore>,
29    flags: CachePair<Flag>,
30    segments: CachePair<Segment>,
31}
32
33impl PersistentDataStoreWrapper {
34    pub(super) fn new(store: Box<dyn PersistentDataStore>, cache_ttl: Option<Duration>) -> Self {
35        Self {
36            store,
37            flags: CachePair::new(String::from("flags"), cache_ttl),
38            segments: CachePair::new(String::from("segments"), cache_ttl),
39        }
40    }
41
42    fn upsert_storage_item<T>(
43        &mut self,
44        key: &str,
45        data: StorageItem<T>,
46    ) -> Result<bool, UpdateError>
47    where
48        StorageItem<T>: WithKind,
49        StorageItem<T>: SerializeToSerializedItem,
50    {
51        let serialized = data
52            .serialize_to_serialized_item()
53            .map_err(UpdateError::ParseError)?;
54        let was_updated = self.store.upsert(StorageItem::<T>::KIND, key, serialized)?;
55
56        Ok(was_updated)
57    }
58
59    fn add_to_cache<T: 'static + Sync + Send + Clone>(
60        was_updated: bool,
61        cache: &CachePair<T>,
62        key: &str,
63        data: StorageItem<T>,
64    ) {
65        if was_updated {
66            cache.insert_single(data.clone(), key);
67            // If the cache is infinite, we need to update the all flags cache. Otherwise, we can
68            // just invalidate the cache and let it re-populate the next time it is required.
69            if cache.cache_is_infinite() {
70                if let Some(mut map) = cache.get_all() {
71                    map.insert(key.to_string(), data);
72                    cache.insert_all(map);
73                }
74            } else {
75                cache.invalidate_all();
76            }
77        } else {
78            cache.invalidate_all();
79            cache.invalidate_single(key);
80        }
81    }
82
83    fn upsert_flag(&mut self, flag_key: &str, data: StorageItem<Flag>) -> Result<(), UpdateError> {
84        let was_updated = self.upsert_storage_item(flag_key, data.clone())?;
85
86        Self::add_to_cache(was_updated, &self.flags, flag_key, data);
87        if !was_updated {
88            let _ = self.flag(flag_key); // Force repopulating the cache
89        }
90
91        Ok(())
92    }
93
94    fn upsert_segment(
95        &mut self,
96        segment_key: &str,
97        data: StorageItem<Segment>,
98    ) -> Result<(), UpdateError> {
99        let was_updated = self.upsert_storage_item(segment_key, data.clone())?;
100
101        Self::add_to_cache(was_updated, &self.segments, segment_key, data);
102        if !was_updated {
103            let _ = self.segment(segment_key); // Force repopulating the cache
104        }
105
106        Ok(())
107    }
108
109    fn cache_flags(&self, flags: HashMap<String, StorageItem<Flag>>) {
110        self.flags.insert_all(flags.clone());
111
112        flags.into_iter().for_each(|(key, flag)| {
113            self.flags.insert_single(flag, &key);
114        });
115    }
116
117    fn cache_segments(&self, segments: HashMap<String, StorageItem<Segment>>) {
118        self.segments.insert_all(segments.clone());
119
120        segments.into_iter().for_each(|(key, segment)| {
121            self.segments.insert_single(segment, &key);
122        });
123    }
124
125    fn cache_items(&self, all_data: AllData<StorageItem<Flag>, StorageItem<Segment>>) {
126        self.cache_flags(all_data.flags);
127        self.cache_segments(all_data.segments);
128        debug!("flag and segment caches have been updated");
129    }
130}
131
132impl Store for PersistentDataStoreWrapper {
133    fn flag(&self, key: &str) -> Option<Flag> {
134        if let Some(item) = self.flags.get_one(key) {
135            return item.into();
136        }
137
138        match self.store.flag(key) {
139            Ok(Some(serialized_item)) => {
140                let storage_item: Result<StorageItem<Flag>, serde_json::Error> =
141                    serialized_item.try_into();
142                match storage_item {
143                    Ok(item) => {
144                        self.flags.insert_single(item.clone(), key);
145                        item.into()
146                    }
147                    Err(e) => {
148                        warn!("failed to convert serialized item into flag: {}", e);
149                        None
150                    }
151                }
152            }
153            Ok(None) => None,
154            Err(e) => {
155                warn!("persistent store failed to retrieve flag: {}", e);
156                None
157            }
158        }
159    }
160
161    fn segment(&self, key: &str) -> Option<Segment> {
162        if let Some(item) = self.segments.get_one(key) {
163            return item.into();
164        }
165
166        match self.store.segment(key) {
167            Ok(Some(serialized_item)) => {
168                let storage_item: Result<StorageItem<Segment>, serde_json::Error> =
169                    serialized_item.try_into();
170                match storage_item {
171                    Ok(item) => {
172                        self.segments.insert_single(item.clone(), key);
173                        item.into()
174                    }
175                    Err(e) => {
176                        warn!("failed to convert serialized item into segment: {}", e);
177                        None
178                    }
179                }
180            }
181            Ok(None) => None,
182            Err(e) => {
183                warn!("persistent store failed to retrieve segment: {}", e);
184                None
185            }
186        }
187    }
188}
189
190impl DataStore for PersistentDataStoreWrapper {
191    fn init(&mut self, all_data: AllData<Flag, Segment>) {
192        self.flags.invalidate_everything();
193        self.segments.invalidate_everything();
194
195        let serialized_data = AllData::<SerializedItem, SerializedItem>::try_from(all_data.clone());
196
197        match serialized_data {
198            Err(e) => warn!(
199                "failed to deserialize payload; cannot initialize store {}",
200                e
201            ),
202            Ok(data) => {
203                let result = self.store.init(data);
204
205                match result {
206                    Ok(()) => {
207                        debug!("data store has been updated with new flag data");
208                        self.cache_items(all_data.into());
209                    }
210                    Err(e) => {
211                        error!("failed to init store: {}", e);
212                        if self.flags.cache_is_infinite() {
213                            debug!("updating non-expiring cache");
214                            self.cache_items(all_data.into())
215                        }
216                    }
217                };
218            }
219        }
220    }
221
222    fn all_flags(&self) -> HashMap<String, Flag> {
223        if let Some(flag_items) = self.flags.get_all() {
224            let flag_iter = flag_items.into_iter().filter_map(|(key, item)| match item {
225                StorageItem::Item(flag) => Some((key, flag)),
226                StorageItem::Tombstone(_) => None,
227            });
228            return HashMap::from_iter(flag_iter);
229        }
230
231        match self.store.all_flags() {
232            Ok(serialized_flags) => {
233                let flags: Result<HashMap<String, StorageItem<Flag>>, serde_json::Error> =
234                    serialized_flags
235                        .into_iter()
236                        .map(|(key, flag)| match flag.try_into() {
237                            Ok(item) => Ok((key, item)),
238                            Err(e) => Err(e),
239                        })
240                        .collect();
241
242                match flags {
243                    Ok(flags) => {
244                        self.cache_flags(flags.clone());
245                        let flag_iter = flags.into_iter().filter_map(|(key, item)| match item {
246                            StorageItem::Item(flag) => Some((key, flag)),
247                            StorageItem::Tombstone(_) => None,
248                        });
249                        HashMap::from_iter(flag_iter)
250                    }
251                    Err(e) => {
252                        warn!("failed to convert serialized items into flags: {}", e);
253                        HashMap::new()
254                    }
255                }
256            }
257            Err(e) => {
258                warn!("persistent store failed to retrieve all flags: {}", e);
259                HashMap::new()
260            }
261        }
262    }
263
264    fn upsert(&mut self, key: &str, data: PatchTarget) -> Result<(), UpdateError> {
265        match data {
266            PatchTarget::Flag(item) => self.upsert_flag(key, item),
267            PatchTarget::Segment(item) => self.upsert_segment(key, item),
268            PatchTarget::Other(v) => Err(UpdateError::InvalidTarget(
269                "flag or segment".to_string(),
270                format!("{:?}", v),
271            )),
272        }
273    }
274
275    fn to_store(&self) -> &dyn Store {
276        self
277    }
278}
279
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use launchdarkly_server_sdk_evaluation::Store;
    use maplit::hashmap;

    use super::PersistentDataStoreWrapper;
    use crate::stores::{
        persistent_store::tests::{InMemoryPersistentDataStore, NullPersistentDataStore},
        store::DataStore,
        store_types::{AllData, PatchTarget, StorageItem},
    };
    use crate::test_common::{basic_flag, basic_segment};

    /// Builds a wrapper over an uninitialized `NullPersistentDataStore` with the given TTL.
    fn null_backed_wrapper(cache_ttl: Option<Duration>) -> PersistentDataStoreWrapper {
        PersistentDataStoreWrapper::new(
            Box::new(NullPersistentDataStore { initialized: false }),
            cache_ttl,
        )
    }

    #[test]
    fn can_retrieve_flags_without_cache() {
        let backing = InMemoryPersistentDataStore {
            data: AllData {
                flags: HashMap::new(),
                segments: HashMap::new(),
            },
            initialized: true,
        };

        // A zero TTL is used so that reads are expected to go through to the backing store.
        let mut wrapper =
            PersistentDataStoreWrapper::new(Box::new(backing), Some(Duration::from_secs(0)));

        wrapper.init(AllData {
            flags: hashmap!["flag".into() => basic_flag("flag")],
            segments: hashmap!["segment".into() => basic_segment("segment")],
        });

        assert_eq!(1, wrapper.all_flags().len());
        assert_eq!(wrapper.flag("flag").unwrap().key, "flag");
        assert_eq!(wrapper.segment("segment").unwrap().key, "segment");
    }

    #[test]
    fn retrieving_flags_uses_cache() {
        let mut wrapper = null_backed_wrapper(Some(Duration::from_secs(100)));

        wrapper.init(AllData {
            flags: hashmap!["flag".into() => basic_flag("flag")],
            segments: HashMap::new(),
        });
        assert!(wrapper.flag("flag").is_some());

        let patch = PatchTarget::Flag(StorageItem::Item(basic_flag("updated-flag")));
        assert!(wrapper.upsert("flag", patch).is_ok());

        assert_eq!(wrapper.flag("flag").unwrap().key, "updated-flag");
    }

    #[test]
    fn retrieving_segments_uses_cache() {
        let mut wrapper = null_backed_wrapper(Some(Duration::from_secs(100)));

        wrapper.init(AllData {
            flags: HashMap::new(),
            segments: hashmap!["segment".into() => basic_segment("segment")],
        });
        assert!(wrapper.segment("segment").is_some());

        let patch = PatchTarget::Segment(StorageItem::Item(basic_segment("updated-segment")));
        assert!(wrapper.upsert("segment", patch).is_ok());

        assert_eq!(wrapper.segment("segment").unwrap().key, "updated-segment");
    }

    #[test]
    fn cache_expires() {
        let mut wrapper = null_backed_wrapper(Some(Duration::from_millis(100)));

        wrapper.init(AllData {
            flags: hashmap!["flag".into() => basic_flag("flag")],
            segments: hashmap!["segment".into() => basic_segment("segment")],
        });

        assert!(wrapper.flag("flag").is_some());
        assert!(wrapper.segment("segment").is_some());

        // Sleep well past the 100ms TTL so both cached entries have expired.
        std::thread::sleep(Duration::from_millis(750));

        assert!(wrapper.flag("flag").is_none());
        assert!(wrapper.segment("segment").is_none());
    }

    #[test]
    fn cache_that_never_expires_should_update_all_flags_cache_when_flag_is_updated() {
        let mut wrapper = null_backed_wrapper(None);

        let mut initial_flag = basic_flag("flag");
        initial_flag.version = 1;

        wrapper.init(AllData {
            flags: hashmap!["flag".into() => initial_flag],
            segments: hashmap!["segment".into() => basic_segment("segment")],
        });

        let mut updated_flag = basic_flag("flag");
        updated_flag.version = 2;
        assert!(wrapper
            .upsert_flag("flag", StorageItem::Item(updated_flag))
            .is_ok());

        let flags = wrapper.all_flags();
        assert_eq!(flags.get("flag").unwrap().version, 2);
    }

    #[test]
    fn cache_that_never_expires_should_update_all_segments_cache_when_segment_is_updated() {
        let mut wrapper = null_backed_wrapper(None);

        let initial_flag = basic_flag("flag");
        let mut initial_segment = basic_segment("segment");
        initial_segment.version = 1;

        wrapper.init(AllData {
            flags: hashmap!["flag".into() => initial_flag],
            segments: hashmap!["segment".into() => initial_segment],
        });

        let mut updated_segment = basic_segment("segment");
        updated_segment.version = 2;
        assert!(wrapper
            .upsert_segment("segment", StorageItem::Item(updated_segment))
            .is_ok());

        let segments = wrapper.segments.get_all().unwrap();
        if let StorageItem::Item(segment) = segments.get("segment").unwrap() {
            assert_eq!(segment.version, 2);
        } else {
            panic!("Failed to retrieve correct segment");
        }
    }
}