1use std::collections::BTreeMap;
11use std::fmt;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16use async_trait::async_trait;
17use mz_repr::CatalogItemId;
18
19use crate::{CachingPolicy, SecretsReader};
20
/// Default time-to-live, in seconds, for entries cached by a `CachingSecretsReader`
/// (five minutes).
pub const DEFAULT_TTL_SECS: u64 = 5 * 60;

/// Runtime-adjustable caching knobs.
///
/// Both values are independent configuration flags that do not publish any other memory,
/// so `Relaxed` atomics are sufficient.
#[derive(Debug)]
struct CachingParameters {
    /// Whether cache lookups and insertions are performed.
    enabled: AtomicBool,
    /// Entry time-to-live in whole milliseconds. Stored in milliseconds rather than seconds
    /// so that sub-second TTLs are not silently truncated (e.g. 500ms to 0).
    ttl_millis: AtomicU64,
}

impl CachingParameters {
    /// Returns whether caching is currently enabled.
    fn enabled(&self) -> bool {
        self.enabled.load(Ordering::Relaxed)
    }

    /// Sets whether caching is enabled, returning the previous setting.
    fn set_enabled(&self, enabled: bool) -> bool {
        self.enabled.swap(enabled, Ordering::Relaxed)
    }

    /// Returns the current entry TTL (millisecond precision).
    fn ttl(&self) -> Duration {
        Duration::from_millis(self.ttl_millis.load(Ordering::Relaxed))
    }

    /// Sets the entry TTL, returning the previous one.
    ///
    /// Precision is whole milliseconds; anything finer is truncated. TTLs longer than
    /// `u64::MAX` milliseconds (~584 million years) saturate.
    fn set_ttl(&self, ttl: Duration) -> Duration {
        let prev = self
            .ttl_millis
            .swap(Self::saturating_millis(ttl), Ordering::Relaxed);
        Duration::from_millis(prev)
    }

    /// Whole milliseconds in `ttl`, saturating at `u64::MAX` instead of wrapping.
    fn saturating_millis(ttl: Duration) -> u64 {
        u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX)
    }
}

impl Default for CachingParameters {
    /// Caching enabled with a TTL of [`DEFAULT_TTL_SECS`] seconds.
    fn default() -> Self {
        CachingParameters {
            enabled: AtomicBool::new(true),
            ttl_millis: AtomicU64::new(Self::saturating_millis(Duration::from_secs(
                DEFAULT_TTL_SECS,
            ))),
        }
    }
}
60
/// A cached secret together with the instant it was inserted into the cache.
struct CacheItem {
    /// The raw secret bytes. Never printed by the `Debug` impl.
    secret: Vec<u8>,
    /// When this entry was inserted; compared against the TTL on lookup.
    ts: Instant,
}

impl CacheItem {
    /// Bundles `secret` with its insertion time `ts`.
    fn new(secret: Vec<u8>, ts: Instant) -> Self {
        Self { secret, ts }
    }
}

impl fmt::Debug for CacheItem {
    /// Formats the entry with the secret bytes redacted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let redacted = "( ... )";
        let mut out = f.debug_struct("CacheItem");
        out.field("secret", &redacted);
        out.field("ts", &self.ts);
        out.finish()
    }
}
83
84#[derive(Clone, Debug)]
85pub struct CachingSecretsReader {
86 inner: Arc<dyn SecretsReader>,
88 cache: Arc<RwLock<BTreeMap<CatalogItemId, CacheItem>>>,
91 policy: Arc<CachingParameters>,
93}
94
95impl CachingSecretsReader {
96 pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
97 CachingSecretsReader {
98 inner: reader,
99 cache: Arc::new(RwLock::new(BTreeMap::new())),
100 policy: Arc::new(CachingParameters::default()),
101 }
102 }
103
104 pub fn set_policy(&self, policy: CachingPolicy) {
105 if policy.enabled {
106 let prev = self.enable_caching();
107 tracing::info!("Enabling secrets caching, previously enabled {prev}");
108 } else {
109 let prev = self.disable_caching();
110 tracing::info!("Disabling secrets caching, previously enabled {prev}");
111 }
112
113 let prev_ttl = self.set_ttl(policy.ttl);
114 if prev_ttl != policy.ttl {
115 tracing::info!(
116 "Updated secrets caching TTL, new {} seconds, prev {} seconds",
117 policy.ttl.as_secs(),
118 prev_ttl.as_secs()
119 );
120 }
121 }
122
123 fn enable_caching(&self) -> bool {
125 self.policy.set_enabled(true)
126 }
127
128 fn disable_caching(&self) -> bool {
130 let was_enabled = self.policy.set_enabled(false);
132 self.cache
133 .write()
134 .expect("CachingSecretsReader panicked!")
135 .clear();
136
137 was_enabled
138 }
139
140 fn set_ttl(&self, ttl: Duration) -> Duration {
142 self.policy.set_ttl(ttl)
143 }
144
145 pub fn invalidate(&self, id: CatalogItemId) -> bool {
147 self.cache
148 .write()
149 .expect("CachingSecretsReader panicked!")
150 .remove(&id)
151 .is_some()
152 }
153}
154
155#[async_trait]
156impl SecretsReader for CachingSecretsReader {
157 async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
158 if self.policy.enabled() {
160 let read_guard = self.cache.read().expect("CachingSecretsReader panicked!");
161 let ttl = self.policy.ttl();
162
163 if let Some(CacheItem { secret, ts }) = read_guard.get(&id) {
165 if Instant::now().duration_since(*ts) < ttl {
166 return Ok(secret.clone());
167 }
168 }
169 }
170
171 let value = self.inner.read(id).await?;
173
174 if self.policy.enabled() {
176 let cache_value = CacheItem::new(value.clone(), Instant::now());
177 self.cache
178 .write()
179 .expect("CachingSecretsReader panicked!")
180 .insert(id, cache_value);
181 }
182
183 Ok(value)
184 }
185}
186
#[cfg(test)]
mod test {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use async_trait::async_trait;
    use mz_repr::CatalogItemId;

    use crate::cache::CachingSecretsReader;
    use crate::{InMemorySecretsController, SecretsController, SecretsReader};

    /// A second read within the TTL is served from the cache, so the underlying reader
    /// records exactly one read.
    #[mz_ore::test(tokio::test)]
    async fn test_read_from_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller
            .ensure(CatalogItemId::User(1), &secret[..])
            .await
            .expect("success");
        let roundtrip = caching_reader.read(id).await.expect("success");

        assert_eq!(roundtrip, secret.to_vec());

        // Second read: expected to be a cache hit.
        let roundtrip2 = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip2, secret.to_vec());

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        assert_eq!(reads[0], id);
    }

    /// Once a cached entry is older than the TTL, the next read goes to the underlying
    /// reader again.
    #[mz_ore::test(tokio::test)]
    async fn test_reading_expired_secret() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        caching_reader.set_ttl(Duration::from_secs(1));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");
        caching_reader.read(id).await.expect("success");

        // Sleep past the 1 second TTL so the cached entry is stale.
        // NOTE(review): `std::thread::sleep` blocks the thread driving this async test;
        // `tokio::time::sleep(..).await` would be the non-blocking equivalent.
        std::thread::sleep(Duration::from_secs(2));

        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        assert_eq!(reads[0], id);
        assert_eq!(reads[1], id);
    }

    /// Disabling caching sends every read to the underlying reader; re-enabling it
    /// restores cache hits.
    #[mz_ore::test(tokio::test)]
    async fn test_disabling_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");

        // Caching enabled (the default): two reads, one underlying read.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        caching_reader.disable_caching();

        // Caching disabled: both reads reach the underlying reader.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        caching_reader.enable_caching();

        // Re-enabled: the first read repopulates the cache, the second hits it.
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// A secret updated in the underlying store is not observed through the cache until
    /// the cached entry expires.
    #[mz_ore::test(tokio::test)]
    async fn updating_cache_values() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");
        // Populate the cache with the original secret.
        caching_reader.read(id).await.expect("success");

        let new_secret = [100, 100];
        // Update the secret behind the cache's back.
        controller.ensure(id, &new_secret).await.expect("success");

        let cached_secret = caching_reader.read(id).await.expect("success");
        // Still within the TTL, so the old (cached) value is returned.
        assert_eq!(cached_secret, secret);

        let reads = testing_reader.drain();
        // Only the initial read reached the underlying reader.
        assert_eq!(reads.len(), 1);

        caching_reader.set_ttl(Duration::from_secs(2));
        // Let the cached entry age past the new 2 second TTL.
        std::thread::sleep(Duration::from_secs(2));

        let read1 = caching_reader.read(id).await.expect("success");
        let read2 = caching_reader.read(id).await.expect("success");
        assert_eq!(read1, new_secret);
        assert_eq!(read1, read2);

        let reads = testing_reader.drain();
        // `read1` went to the underlying reader; `read2` was served from the cache.
        assert_eq!(reads.len(), 1);
    }

    /// `invalidate` drops a cached entry (forcing the next read through) and reports
    /// whether an entry was present.
    #[mz_ore::test(tokio::test)]
    async fn test_invalidate() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");
        caching_reader.read(id).await.expect("success");

        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        assert!(caching_reader.invalidate(id));

        // The entry is gone, so this read reaches the underlying reader.
        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
        assert_eq!(reads[0], id);

        // Invalidating an id that was never cached reports `false`.
        assert!(!caching_reader.invalidate(CatalogItemId::User(999)));
    }

    /// A `SecretsReader` wrapper that records the id of every read it forwards, so tests
    /// can count how many reads bypassed the cache.
    #[derive(Debug, Clone)]
    pub struct TestingSecretsReader {
        /// The reader that reads are forwarded to.
        reader: Arc<dyn SecretsReader>,
        /// Ids of forwarded reads, in order; shared between clones.
        reads: Arc<Mutex<Vec<CatalogItemId>>>,
    }

    impl TestingSecretsReader {
        /// Wraps `reader` with an empty read log.
        pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
            TestingSecretsReader {
                reader,
                reads: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns the recorded read ids and clears the log.
        pub fn drain(&self) -> Vec<CatalogItemId> {
            self.reads
                .lock()
                .expect("TracingSecretsController panicked!")
                .drain(..)
                .collect()
        }

        /// Appends `id` to the read log.
        fn record(&self, id: CatalogItemId) {
            self.reads
                .lock()
                .expect("TracingSecretsController panicked!")
                .push(id);
        }
    }

    #[async_trait]
    impl SecretsReader for TestingSecretsReader {
        /// Forwards the read, then records `id` whether or not the read succeeded.
        async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
            let result = self.reader.read(id).await;
            self.record(id);
            result
        }
    }
}