mz_secrets/
cache.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::fmt;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16use async_trait::async_trait;
17use mz_repr::CatalogItemId;
18
19use crate::{CachingPolicy, SecretsReader};
20
/// How long a cached value stays valid by default, expressed in __seconds__ (five minutes).
pub const DEFAULT_TTL_SECS: u64 = 5 * 60;
23
24#[derive(Debug)]
25struct CachingParameters {
26    /// Whether caching is enabled, can be changed at runtime.
27    enabled: AtomicBool,
28    /// Cache values only live for so long.
29    ttl_secs: AtomicU64,
30}
31
32impl CachingParameters {
33    fn enabled(&self) -> bool {
34        self.enabled.load(Ordering::Relaxed)
35    }
36
37    fn set_enabled(&self, enabled: bool) -> bool {
38        self.enabled.swap(enabled, Ordering::Relaxed)
39    }
40
41    fn ttl(&self) -> Duration {
42        let secs = self.ttl_secs.load(Ordering::Relaxed);
43        Duration::from_secs(secs)
44    }
45
46    fn set_ttl(&self, ttl: Duration) -> Duration {
47        let prev = self.ttl_secs.swap(ttl.as_secs(), Ordering::Relaxed);
48        Duration::from_secs(prev)
49    }
50}
51
52impl Default for CachingParameters {
53    fn default() -> Self {
54        CachingParameters {
55            enabled: AtomicBool::new(true),
56            ttl_secs: AtomicU64::new(DEFAULT_TTL_SECS),
57        }
58    }
59}
60
/// A single cache entry: the secret's bytes plus the moment they were stored.
///
/// Note: `Debug` is written by hand so that secret contents never end up in logs.
struct CacheItem {
    /// The raw secret bytes.
    secret: Vec<u8>,
    /// When this entry was created; used to decide whether it has outlived the TTL.
    ts: Instant,
}

impl CacheItem {
    /// Builds an entry from the secret bytes and the time they were obtained.
    fn new(secret: Vec<u8>, ts: Instant) -> Self {
        Self { secret, ts }
    }
}

impl fmt::Debug for CacheItem {
    /// Formats the entry with the secret replaced by a fixed placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Deliberately ignore the secret so it can never be printed.
        let Self { secret: _, ts } = self;

        let mut out = f.debug_struct("CacheItem");
        out.field("secret", &"( ... )");
        out.field("ts", ts);
        out.finish()
    }
}
83
84#[derive(Clone, Debug)]
85pub struct CachingSecretsReader {
86    /// The underlying secrets, source of truth.
87    inner: Arc<dyn SecretsReader>,
88    /// In-memory cache, not having a size limit or eviction policy is okay because we limit users
89    /// to 100 secrets, which should not be a problem to store in-memory.
90    cache: Arc<RwLock<BTreeMap<CatalogItemId, CacheItem>>>,
91    /// Caching policy, can change at runtime, e.g. via LaunchDarkly.
92    policy: Arc<CachingParameters>,
93}
94
95impl CachingSecretsReader {
96    pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
97        CachingSecretsReader {
98            inner: reader,
99            cache: Arc::new(RwLock::new(BTreeMap::new())),
100            policy: Arc::new(CachingParameters::default()),
101        }
102    }
103
104    pub fn set_policy(&self, policy: CachingPolicy) {
105        if policy.enabled {
106            let prev = self.enable_caching();
107            tracing::info!("Enabling secrets caching, previously enabled {prev}");
108        } else {
109            let prev = self.disable_caching();
110            tracing::info!("Disabling secrets caching, previously enabled {prev}");
111        }
112
113        let prev_ttl = self.set_ttl(policy.ttl);
114        if prev_ttl != policy.ttl {
115            tracing::info!(
116                "Updated secrets caching TTL, new {} seconds, prev {} seconds",
117                policy.ttl.as_secs(),
118                prev_ttl.as_secs()
119            );
120        }
121    }
122
123    /// Enables caching, returning whether we were previously enabled.
124    fn enable_caching(&self) -> bool {
125        self.policy.set_enabled(true)
126    }
127
128    /// Disables caching, returning whether we were previously enabled.
129    fn disable_caching(&self) -> bool {
130        // Disable and clear the cache of all existing values.
131        let was_enabled = self.policy.set_enabled(false);
132        self.cache
133            .write()
134            .expect("CachingSecretsReader panicked!")
135            .clear();
136
137        was_enabled
138    }
139
140    /// Sets a new "time to live" for cache values, returning the old TTL.
141    fn set_ttl(&self, ttl: Duration) -> Duration {
142        self.policy.set_ttl(ttl)
143    }
144}
145
146#[async_trait]
147impl SecretsReader for CachingSecretsReader {
148    async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
149        // Iff our cache is enabled will we read from it.
150        if self.policy.enabled() {
151            let read_guard = self.cache.read().expect("CachingSecretsReader panicked!");
152            let ttl = self.policy.ttl();
153
154            // If we have a cached value we still need to check if it's expired.
155            if let Some(CacheItem { secret, ts }) = read_guard.get(&id) {
156                if Instant::now().duration_since(*ts) < ttl {
157                    return Ok(secret.clone());
158                }
159            }
160        }
161
162        // Otherwise, we need to read from source!
163        let value = self.inner.read(id).await?;
164
165        // Cache it, if caching is enabled.
166        if self.policy.enabled() {
167            let cache_value = CacheItem::new(value.clone(), Instant::now());
168            self.cache
169                .write()
170                .expect("CachingSecretsReader panicked!")
171                .insert(id, cache_value);
172        }
173
174        Ok(value)
175    }
176}
177
#[cfg(test)]
mod test {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use async_trait::async_trait;
    use mz_repr::CatalogItemId;

    use crate::cache::CachingSecretsReader;
    use crate::{InMemorySecretsController, SecretsController, SecretsReader};

    /// A second read of the same secret is served from the cache, not the source.
    #[mz_ore::test(tokio::test)]
    async fn test_read_from_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Add a new secret and read it back.
        controller
            .ensure(id, &secret[..])
            .await
            .expect("success");
        let roundtrip = caching_reader.read(id).await.expect("success");

        // The secret should be correct.
        assert_eq!(roundtrip, secret.to_vec());

        // Read it a second time, our cache should be populated now.
        let roundtrip2 = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip2, secret.to_vec());

        // We should only have one read, as the second should have hit the cache.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
        assert_eq!(reads[0], id);
    }

    /// Once a cached value is older than the TTL, the next read goes back to the source.
    #[mz_ore::test(tokio::test)]
    async fn test_reading_expired_secret() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        // Update the caching policy to expire secrets quickly.
        caching_reader.set_ttl(Duration::from_secs(1));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Store our secret.
        controller.ensure(id, &secret).await.expect("success");
        // Read it once to populate the cache.
        caching_reader.read(id).await.expect("success");

        // Wait for it to timeout. The TTL is tracked in whole seconds, so we really do have
        // to wait this long.
        std::thread::sleep(Duration::from_secs(2));

        // Read it again, which should hit the underlying source.
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        // Our value should have expired so we should have read from the underlying source.
        assert_eq!(reads[0], id);
        assert_eq!(reads[1], id);
    }

    /// Disabling the cache sends every read to the source; re-enabling restores caching.
    #[mz_ore::test(tokio::test)]
    async fn test_disabling_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Store a value.
        controller.ensure(id, &secret).await.expect("success");

        // Read twice, which should hit the cache the second time.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        // Disable caching.
        caching_reader.disable_caching();

        // Read twice again, both should hit the source.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        // Re-enable caching.
        caching_reader.enable_caching();

        // Read twice again, first should populate the cache, second should hit it.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// A stale value is served until it expires, after which the updated value is cached.
    #[mz_ore::test(tokio::test)]
    async fn updating_cache_values() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        // Store an initial value.
        controller.ensure(id, &secret).await.expect("success");
        // Read to load the value into the cache.
        caching_reader.read(id).await.expect("success");

        // Update the stored secret.
        let new_secret = [100, 100];
        controller.ensure(id, &new_secret).await.expect("success");

        // Reading from the cache should give us the _old_ value.
        let cached_secret = caching_reader.read(id).await.expect("success");
        assert_eq!(cached_secret, secret);

        // We should only have registered one read, since we made a stale read from the cache.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        // Shorten the TTL and wait for the secret to expire.
        caching_reader.set_ttl(Duration::from_secs(2));
        std::thread::sleep(Duration::from_secs(2));

        // Since the cache value is expired, we should read from source, and get the new value.
        let read1 = caching_reader.read(id).await.expect("success");
        let read2 = caching_reader.read(id).await.expect("success");
        assert_eq!(read1, new_secret);
        assert_eq!(read1, read2);

        // Should only have one read since we updated the cache.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// A secrets reader that forwards to another reader and logs the ID of every read it is
    /// asked to perform. Used to observe whether the caching reader hit its cache or the
    /// underlying source.
    #[derive(Debug, Clone)]
    pub struct TestingSecretsReader {
        /// The underlying secrets reader that actually serves the reads.
        reader: Arc<dyn SecretsReader>,
        /// A log of reads that have been made, shared between clones.
        reads: Arc<Mutex<Vec<CatalogItemId>>>,
    }

    impl TestingSecretsReader {
        /// Wraps `reader`, starting with an empty read log.
        pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
            TestingSecretsReader {
                reader,
                reads: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns every read recorded so far, in order, and leaves the log empty.
        pub fn drain(&self) -> Vec<CatalogItemId> {
            let mut reads = self
                .reads
                .lock()
                .expect("TestingSecretsReader panicked!");
            std::mem::take(&mut *reads)
        }

        /// Appends `id` to the read log.
        fn record(&self, id: CatalogItemId) {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .push(id);
        }
    }

    #[async_trait]
    impl SecretsReader for TestingSecretsReader {
        /// Forwards the read to the wrapped reader and logs `id`, whether or not it succeeded.
        async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
            let result = self.reader.read(id).await;
            self.record(id);
            result
        }
    }
}