1use std::collections::BTreeMap;
11use std::fmt;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16use async_trait::async_trait;
17use mz_repr::CatalogItemId;
18
19use crate::{CachingPolicy, SecretsReader};
20
/// Default time-to-live, in seconds, for a secret held in the cache (five minutes).
pub const DEFAULT_TTL_SECS: u64 = 5 * 60;
23
24#[derive(Debug)]
25struct CachingParameters {
26 enabled: AtomicBool,
28 ttl_secs: AtomicU64,
30}
31
32impl CachingParameters {
33 fn enabled(&self) -> bool {
34 self.enabled.load(Ordering::Relaxed)
35 }
36
37 fn set_enabled(&self, enabled: bool) -> bool {
38 self.enabled.swap(enabled, Ordering::Relaxed)
39 }
40
41 fn ttl(&self) -> Duration {
42 let secs = self.ttl_secs.load(Ordering::Relaxed);
43 Duration::from_secs(secs)
44 }
45
46 fn set_ttl(&self, ttl: Duration) -> Duration {
47 let prev = self.ttl_secs.swap(ttl.as_secs(), Ordering::Relaxed);
48 Duration::from_secs(prev)
49 }
50}
51
52impl Default for CachingParameters {
53 fn default() -> Self {
54 CachingParameters {
55 enabled: AtomicBool::new(true),
56 ttl_secs: AtomicU64::new(DEFAULT_TTL_SECS),
57 }
58 }
59}
60
/// A secret's bytes together with the instant at which they were cached.
struct CacheItem {
    /// Raw secret contents; deliberately redacted in the `Debug` output.
    secret: Vec<u8>,
    /// When the secret was cached; compared against the TTL on lookup.
    ts: Instant,
}

impl CacheItem {
    /// Bundles `secret` with the timestamp `ts`.
    fn new(secret: Vec<u8>, ts: Instant) -> Self {
        Self { secret, ts }
    }
}

impl fmt::Debug for CacheItem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Never print secret material, even in debug output.
        let redacted = "( ... )";
        let mut builder = f.debug_struct("CacheItem");
        builder.field("secret", &redacted);
        builder.field("ts", &self.ts);
        builder.finish()
    }
}
83
84#[derive(Clone, Debug)]
85pub struct CachingSecretsReader {
86 inner: Arc<dyn SecretsReader>,
88 cache: Arc<RwLock<BTreeMap<CatalogItemId, CacheItem>>>,
91 policy: Arc<CachingParameters>,
93}
94
95impl CachingSecretsReader {
96 pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
97 CachingSecretsReader {
98 inner: reader,
99 cache: Arc::new(RwLock::new(BTreeMap::new())),
100 policy: Arc::new(CachingParameters::default()),
101 }
102 }
103
104 pub fn set_policy(&self, policy: CachingPolicy) {
105 if policy.enabled {
106 let prev = self.enable_caching();
107 tracing::info!("Enabling secrets caching, previously enabled {prev}");
108 } else {
109 let prev = self.disable_caching();
110 tracing::info!("Disabling secrets caching, previously enabled {prev}");
111 }
112
113 let prev_ttl = self.set_ttl(policy.ttl);
114 if prev_ttl != policy.ttl {
115 tracing::info!(
116 "Updated secrets caching TTL, new {} seconds, prev {} seconds",
117 policy.ttl.as_secs(),
118 prev_ttl.as_secs()
119 );
120 }
121 }
122
123 fn enable_caching(&self) -> bool {
125 self.policy.set_enabled(true)
126 }
127
128 fn disable_caching(&self) -> bool {
130 let was_enabled = self.policy.set_enabled(false);
132 self.cache
133 .write()
134 .expect("CachingSecretsReader panicked!")
135 .clear();
136
137 was_enabled
138 }
139
140 fn set_ttl(&self, ttl: Duration) -> Duration {
142 self.policy.set_ttl(ttl)
143 }
144}
145
146#[async_trait]
147impl SecretsReader for CachingSecretsReader {
148 async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
149 if self.policy.enabled() {
151 let read_guard = self.cache.read().expect("CachingSecretsReader panicked!");
152 let ttl = self.policy.ttl();
153
154 if let Some(CacheItem { secret, ts }) = read_guard.get(&id) {
156 if Instant::now().duration_since(*ts) < ttl {
157 return Ok(secret.clone());
158 }
159 }
160 }
161
162 let value = self.inner.read(id).await?;
164
165 if self.policy.enabled() {
167 let cache_value = CacheItem::new(value.clone(), Instant::now());
168 self.cache
169 .write()
170 .expect("CachingSecretsReader panicked!")
171 .insert(id, cache_value);
172 }
173
174 Ok(value)
175 }
176}
177
#[cfg(test)]
mod test {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use async_trait::async_trait;
    use mz_repr::CatalogItemId;

    use crate::cache::CachingSecretsReader;
    use crate::{InMemorySecretsController, SecretsController, SecretsReader};

    /// A second read within the TTL is served from the cache.
    #[mz_ore::test(tokio::test)]
    async fn test_read_from_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller
            .ensure(id, &secret[..])
            .await
            .expect("success");
        let roundtrip = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip, secret.to_vec());

        let roundtrip2 = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip2, secret.to_vec());

        // Only the first read should have reached the underlying reader.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
        assert_eq!(reads[0], id);
    }

    /// Once an entry is older than the TTL, the underlying reader is consulted again.
    #[mz_ore::test(tokio::test)]
    async fn test_reading_expired_secret() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        caching_reader.set_ttl(Duration::from_secs(1));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");
        caching_reader.read(id).await.expect("success");

        // Wait past the 1s TTL so the cached entry expires.
        std::thread::sleep(Duration::from_secs(2));

        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);
        assert_eq!(reads[0], id);
        assert_eq!(reads[1], id);
    }

    /// Disabling the cache forces every read through; re-enabling restores caching.
    #[mz_ore::test(tokio::test)]
    async fn test_disabling_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");

        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        caching_reader.disable_caching();

        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        caching_reader.enable_caching();

        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// An updated secret is only observed after the cached copy expires.
    #[mz_ore::test(tokio::test)]
    async fn updating_cache_values() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = CatalogItemId::User(1);

        controller.ensure(id, &secret).await.expect("success");
        caching_reader.read(id).await.expect("success");

        let new_secret = [100, 100];
        controller.ensure(id, &new_secret).await.expect("success");

        // Still within the TTL, so the stale value is returned from the cache.
        let cached_secret = caching_reader.read(id).await.expect("success");
        assert_eq!(cached_secret, secret);

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        caching_reader.set_ttl(Duration::from_secs(2));
        std::thread::sleep(Duration::from_secs(2));

        // The first read refreshes the entry; the second is served from the cache.
        let read1 = caching_reader.read(id).await.expect("success");
        let read2 = caching_reader.read(id).await.expect("success");
        assert_eq!(read1, new_secret);
        assert_eq!(read1, read2);

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// A `SecretsReader` wrapper that records every id it is asked to read, so
    /// tests can count how many reads reached the underlying reader.
    #[derive(Debug, Clone)]
    pub struct TestingSecretsReader {
        /// The reader that actually serves the secrets.
        reader: Arc<dyn SecretsReader>,
        /// Ids read so far, in order; shared between clones.
        reads: Arc<Mutex<Vec<CatalogItemId>>>,
    }

    impl TestingSecretsReader {
        pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
            TestingSecretsReader {
                reader,
                reads: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns and clears the ids recorded so far.
        pub fn drain(&self) -> Vec<CatalogItemId> {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .drain(..)
                .collect()
        }

        /// Records that `id` was read.
        fn record(&self, id: CatalogItemId) {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .push(id);
        }
    }

    #[async_trait]
    impl SecretsReader for TestingSecretsReader {
        async fn read(&self, id: CatalogItemId) -> Result<Vec<u8>, anyhow::Error> {
            let result = self.reader.read(id).await;
            // Record the attempt whether or not it succeeded.
            self.record(id);
            result
        }
    }
}