Skip to main content

mz_secrets/
cache.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::fmt;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16use async_trait::async_trait;
17use mz_repr::CatalogItemId;
18
19use crate::{CachingPolicy, SecretsReader};
20
/// Default "time to live" for a single cache value, represented in __seconds__.
pub const DEFAULT_TTL_SECS: u64 = Duration::from_secs(300).as_secs();

/// Runtime-adjustable caching knobs: an on/off switch and a "time to live" for cached values.
///
/// Both values live in atomics so they can be read on every lookup and swapped at runtime
/// without taking a lock.
#[derive(Debug)]
struct CachingParameters {
    /// Whether caching is enabled, can be changed at runtime.
    enabled: AtomicBool,
    /// How long a cached value stays valid, stored in __nanoseconds__.
    ///
    /// Nanoseconds (rather than whole seconds) let a TTL round-trip through `set_ttl` / `ttl`
    /// without silently dropping its sub-second component.
    ttl_nanos: AtomicU64,
}

impl CachingParameters {
    /// Converts `ttl` to whole nanoseconds, saturating at `u64::MAX` (roughly 584 years).
    fn duration_to_nanos(ttl: Duration) -> u64 {
        u64::try_from(ttl.as_nanos()).unwrap_or(u64::MAX)
    }

    /// Returns whether caching is currently enabled.
    fn enabled(&self) -> bool {
        self.enabled.load(Ordering::Relaxed)
    }

    /// Sets whether caching is enabled, returning the previous value.
    fn set_enabled(&self, enabled: bool) -> bool {
        self.enabled.swap(enabled, Ordering::Relaxed)
    }

    /// Returns the current "time to live" for cache values.
    fn ttl(&self) -> Duration {
        Duration::from_nanos(self.ttl_nanos.load(Ordering::Relaxed))
    }

    /// Sets a new "time to live" for cache values, returning the previous one.
    ///
    /// Sub-second precision is preserved; TTLs longer than `u64::MAX` nanoseconds saturate.
    fn set_ttl(&self, ttl: Duration) -> Duration {
        let prev = self
            .ttl_nanos
            .swap(Self::duration_to_nanos(ttl), Ordering::Relaxed);
        Duration::from_nanos(prev)
    }
}

impl Default for CachingParameters {
    /// Caching enabled, with a TTL of `DEFAULT_TTL_SECS` seconds.
    fn default() -> Self {
        CachingParameters {
            enabled: AtomicBool::new(true),
            ttl_nanos: AtomicU64::new(Self::duration_to_nanos(Duration::from_secs(
                DEFAULT_TTL_SECS,
            ))),
        }
    }
}
60
/// A single cached secret together with the timestamp used for its TTL check.
///
/// `Debug` is implemented by hand below so that the secret bytes are never printed, e.g. in
/// logs.
struct CacheItem {
    secret: Vec<u8>,
    ts: Instant,
}

impl CacheItem {
    /// Wraps `secret`, recording `ts` as the instant it was cached.
    fn new(secret: Vec<u8>, ts: Instant) -> Self {
        Self { ts, secret }
    }
}

impl fmt::Debug for CacheItem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Redact the payload; only the timestamp is safe to show.
        const REDACTED: &str = "( ... )";
        let Self { secret: _, ts } = self;
        let mut out = f.debug_struct("CacheItem");
        out.field("secret", &REDACTED);
        out.field("ts", ts);
        out.finish()
    }
}
83
84#[derive(Clone, Debug)]
85pub struct CachingSecretsReader {
86    /// The underlying secrets, source of truth.
87    inner: Arc<dyn SecretsReader>,
88    /// In-memory cache, not having a size limit or eviction policy is okay because we limit users
89    /// to 100 secrets, which should not be a problem to store in-memory.
90    cache: Arc<RwLock<BTreeMap<CatalogItemId, CacheItem>>>,
91    /// Caching policy, can change at runtime, e.g. via LaunchDarkly.
92    policy: Arc<CachingParameters>,
93}
94
95impl CachingSecretsReader {
96    pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
97        CachingSecretsReader {
98            inner: reader,
99            cache: Arc::new(RwLock::new(BTreeMap::new())),
100            policy: Arc::new(CachingParameters::default()),
101        }
102    }
103
104    pub fn set_policy(&self, policy: CachingPolicy) {
105        if policy.enabled {
106            let prev = self.enable_caching();
107            tracing::info!("Enabling secrets caching, previously enabled {prev}");
108        } else {
109            let prev = self.disable_caching();
110            tracing::info!("Disabling secrets caching, previously enabled {prev}");
111        }
112
113        let prev_ttl = self.set_ttl(policy.ttl);
114        if prev_ttl != policy.ttl {
115            tracing::info!(
116                "Updated secrets caching TTL, new {} seconds, prev {} seconds",
117                policy.ttl.as_secs(),
118                prev_ttl.as_secs()
119            );
120        }
121    }
122
123    /// Enables caching, returning whether we were previously enabled.
124    fn enable_caching(&self) -> bool {
125        self.policy.set_enabled(true)
126    }
127
128    /// Disables caching, returning whether we were previously enabled.
129    fn disable_caching(&self) -> bool {
130        // Disable and clear the cache of all existing values.
131        let was_enabled = self.policy.set_enabled(false);
132        self.cache
133            .write()
134            .expect("CachingSecretsReader panicked!")
135            .clear();
136
137        was_enabled
138    }
139
140    /// Sets a new "time to live" for cache values, returning the old TTL.
141    fn set_ttl(&self, ttl: Duration) -> Duration {
142        self.policy.set_ttl(ttl)
143    }
144
145    /// Invalidates a single cached secret, returning whether the cache contained it.
146    pub fn invalidate(&self, id: CatalogItemId) -> bool {
147        self.cache
148            .write()
149            .expect("CachingSecretsReader panicked!")
150            .remove(&id)
151            .is_some()
152    }
153}
154
155#[async_trait]
156impl SecretsReader for CachingSecretsReader {
157    async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
158        // Iff our cache is enabled will we read from it.
159        if self.policy.enabled() {
160            let read_guard = self.cache.read().expect("CachingSecretsReader panicked!");
161            let ttl = self.policy.ttl();
162
163            // If we have a cached value we still need to check if it's expired.
164            if let Some(CacheItem { secret, ts }) = read_guard.get(&id) {
165                if Instant::now().duration_since(*ts) < ttl {
166                    return Ok(secret.clone());
167                }
168            }
169        }
170
171        // Otherwise, we need to read from source!
172        let value = self.inner.read(id).await?;
173
174        // Cache it, if caching is enabled.
175        if self.policy.enabled() {
176            let cache_value = CacheItem::new(value.clone(), Instant::now());
177            self.cache
178                .write()
179                .expect("CachingSecretsReader panicked!")
180                .insert(id, cache_value);
181        }
182
183        Ok(value)
184    }
185}
186
#[cfg(test)]
mod test {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use async_trait::async_trait;
    use mz_repr::CatalogItemId;

    use crate::cache::{CachingSecretsReader, DEFAULT_TTL_SECS};
    use crate::{InMemorySecretsController, SecretsController, SecretsReader};

    #[mz_ore::test(tokio::test)]
    async fn test_read_from_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Add a new secret and read it back.
        controller.ensure(id, &secret[..]).await.expect("success");
        let roundtrip = caching_reader.read(id).await.expect("success");

        // The secret should be correct.
        assert_eq!(roundtrip, secret.to_vec());

        // Read it a second time, our cache should be populated now.
        let roundtrip2 = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip2, secret.to_vec());

        // We should only have one read, as the second should have hit the cache.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
        assert_eq!(reads[0], id);
    }

    #[mz_ore::test(tokio::test)]
    async fn test_reading_expired_secret() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        // Update the caching policy to expire secrets quickly.
        caching_reader.set_ttl(Duration::from_secs(1));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Store our secret.
        controller.ensure(id, &secret).await.expect("success");
        // Read it once to populate the cache.
        caching_reader.read(id).await.expect("success");

        // Wait for it to time out. The cache compares against `std::time::Instant`, so real
        // wall-clock time has to pass; blocking is acceptable because this test spawns no
        // other tasks.
        std::thread::sleep(Duration::from_secs(2));

        // Read it again, which should hit the underlying source.
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        // Our value should have expired so we should have read from the underlying source.
        assert_eq!(reads[0], id);
        assert_eq!(reads[1], id);
    }

    #[mz_ore::test(tokio::test)]
    async fn test_disabling_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Store a value.
        controller.ensure(id, &secret).await.expect("success");

        // Read twice, which should hit the cache the second time.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        // Disable caching.
        caching_reader.disable_caching();

        // Read twice again, both should hit the source.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        // Re-enable caching.
        caching_reader.enable_caching();

        // Read twice again, first should populate the cache, second should hit it.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    #[mz_ore::test(tokio::test)]
    async fn updating_cache_values() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Store an initial value.
        controller.ensure(id, &secret).await.expect("success");
        // Read to load the value into the cache.
        caching_reader.read(id).await.expect("success");

        // Update the stored secret.
        let new_secret = [100, 100];
        controller.ensure(id, &new_secret).await.expect("success");

        // Reading from the cache should give us the _old_ value.
        let cached_secret = caching_reader.read(id).await.expect("success");
        assert_eq!(cached_secret, secret);

        // We should only have registered one read, since we made a stale read from the cache.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        // Wait for the secret to expire. Real wall-clock time has to pass (see
        // `test_reading_expired_secret`).
        caching_reader.set_ttl(Duration::from_secs(2));
        std::thread::sleep(Duration::from_secs(2));

        // Since the cache value is expired, we should read from source, and get the new value.
        let read1 = caching_reader.read(id).await.expect("success");
        let read2 = caching_reader.read(id).await.expect("success");
        assert_eq!(read1, new_secret);
        assert_eq!(read1, read2);

        // Should only have one read since we updated the cache.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    #[mz_ore::test(tokio::test)]
    async fn test_invalidate() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");
        caching_reader.read(id).await.expect("success");

        // The second read is served from the cache.
        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        // Invalidating a cached id reports that it was present...
        assert!(caching_reader.invalidate(id));

        // ...and forces the next read back to the source.
        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
        assert_eq!(reads[0], id);

        // Invalidating an id that was never cached reports `false`.
        assert!(!caching_reader.invalidate(CatalogItemId::User(999)));
    }

    #[mz_ore::test(tokio::test)]
    async fn test_set_ttl_returns_previous() {
        let controller = InMemorySecretsController::new();
        let caching_reader = CachingSecretsReader::new(controller.reader());

        // A fresh reader starts out with the default TTL.
        let prev = caching_reader.set_ttl(Duration::from_secs(10));
        assert_eq!(prev, Duration::from_secs(DEFAULT_TTL_SECS));

        // Each update hands back the value it replaced.
        let prev = caching_reader.set_ttl(Duration::from_secs(20));
        assert_eq!(prev, Duration::from_secs(10));
    }

    /// A [`SecretsReader`] wrapper that forwards every read to an inner reader and records the
    /// id of each read (successful or not), so tests can count how often the cache missed.
    #[derive(Debug, Clone)]
    pub struct TestingSecretsReader {
        /// The underlying secrets reader.
        reader: Arc<dyn SecretsReader>,
        /// A log of reads that have been made.
        reads: Arc<Mutex<Vec<CatalogItemId>>>,
    }

    impl TestingSecretsReader {
        pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
            TestingSecretsReader {
                reader,
                reads: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Drains and returns all reads recorded so far, in the order they happened.
        pub fn drain(&self) -> Vec<CatalogItemId> {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .drain(..)
                .collect()
        }

        /// Record that a read of `id` has occurred.
        fn record(&self, id: CatalogItemId) {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .push(id);
        }
    }

    #[async_trait]
    impl SecretsReader for TestingSecretsReader {
        async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
            let result = self.reader.read(id).await;
            self.record(id);
            result
        }
    }
}