use std::collections::BTreeMap;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use mz_repr::GlobalId;
use crate::{CachingPolicy, SecretsReader};
/// Default time-to-live, in seconds, for cached secrets (300 seconds, i.e. 5 minutes).
///
/// NOTE(review): this is a `const` of an interior-mutable type, so every use site
/// gets its own fresh `AtomicU64` initialized to 300. Storing into
/// `DEFAULT_TTL_SECS` directly therefore has no lasting effect. A plain `u64` would
/// express this more clearly, but changing the type of this `pub` item could break
/// callers outside this file.
pub const DEFAULT_TTL_SECS: AtomicU64 = AtomicU64::new(Duration::from_secs(300).as_secs());
/// Runtime-adjustable caching knobs shared by every clone of a
/// `CachingSecretsReader`.
///
/// Both fields are atomics, so they can be read on the lookup path and updated
/// through a shared reference without taking a lock. Each field is read and
/// written independently, and `Relaxed` ordering is used for all accesses.
#[derive(Debug)]
struct CachingParameters {
    /// Whether lookups may be served from, and results stored into, the cache.
    enabled: AtomicBool,
    /// Time-to-live of a cache entry, in milliseconds.
    ///
    /// Kept at millisecond precision, rather than whole seconds, so that a
    /// sub-second TTL such as 500ms is not silently truncated to zero. A zero TTL
    /// would turn every lookup into a cache miss.
    ttl_millis: AtomicU64,
}

impl CachingParameters {
    /// Returns whether caching is currently enabled.
    fn enabled(&self) -> bool {
        self.enabled.load(Ordering::Relaxed)
    }

    /// Sets whether caching is enabled, returning the previous setting.
    fn set_enabled(&self, enabled: bool) -> bool {
        self.enabled.swap(enabled, Ordering::Relaxed)
    }

    /// Returns the current time-to-live of cache entries.
    fn ttl(&self) -> Duration {
        Duration::from_millis(self.ttl_millis.load(Ordering::Relaxed))
    }

    /// Sets the time-to-live of cache entries, returning the previous TTL.
    ///
    /// Precision below one millisecond is dropped. TTLs too large to express as a
    /// `u64` number of milliseconds saturate to `u64::MAX` milliseconds.
    fn set_ttl(&self, ttl: Duration) -> Duration {
        let prev = self
            .ttl_millis
            .swap(Self::ttl_to_millis(ttl), Ordering::Relaxed);
        Duration::from_millis(prev)
    }

    /// Converts `ttl` to whole milliseconds, saturating at `u64::MAX`.
    fn ttl_to_millis(ttl: Duration) -> u64 {
        u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX)
    }
}

impl Default for CachingParameters {
    /// Caching enabled, with a TTL of `DEFAULT_TTL_SECS` seconds.
    fn default() -> Self {
        // `DEFAULT_TTL_SECS` is a `const`, so this consumes a fresh copy of it.
        let default_ttl = Duration::from_secs(DEFAULT_TTL_SECS.into_inner());
        CachingParameters {
            enabled: AtomicBool::new(true),
            ttl_millis: AtomicU64::new(Self::ttl_to_millis(default_ttl)),
        }
    }
}
/// A secret value together with the instant at which it was stored in the cache.
struct CacheItem {
    /// The raw secret bytes.
    secret: Vec<u8>,
    /// When this entry was created; the reader compares its age against the TTL.
    ts: Instant,
}

impl CacheItem {
    /// Wraps `secret`, stamping it with `ts`.
    fn new(secret: Vec<u8>, ts: Instant) -> Self {
        Self { secret, ts }
    }
}

/// Hand-written rather than derived so the secret bytes never show up in debug
/// output; only a placeholder and the timestamp are printed.
impl fmt::Debug for CacheItem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let redacted = "( ... )";
        let mut builder = f.debug_struct("CacheItem");
        builder.field("secret", &redacted);
        builder.field("ts", &self.ts);
        builder.finish()
    }
}
/// A `SecretsReader` that caches the secrets returned by an inner reader.
///
/// Cached entries are served while they are younger than the configured TTL. The
/// cache can be enabled or disabled at runtime via
/// [`CachingSecretsReader::set_policy`]. Clones share the same cache and the same
/// policy, because both are held behind `Arc`s.
#[derive(Clone, Debug)]
pub struct CachingSecretsReader {
    /// The reader consulted on a cache miss.
    inner: Arc<dyn SecretsReader>,
    /// Cached secrets, keyed by the secret's id.
    cache: Arc<RwLock<BTreeMap<GlobalId, CacheItem>>>,
    /// Runtime-adjustable caching parameters (enabled flag and TTL).
    policy: Arc<CachingParameters>,
}

impl CachingSecretsReader {
    /// Wraps `reader` with an initially empty cache using the default
    /// `CachingParameters`.
    pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
        CachingSecretsReader {
            inner: reader,
            cache: Arc::new(RwLock::new(BTreeMap::new())),
            policy: Arc::new(CachingParameters::default()),
        }
    }

    /// Applies `policy`: enables or disables caching and updates the TTL.
    ///
    /// Disabling caching also clears every cached secret.
    pub fn set_policy(&self, policy: CachingPolicy) {
        if policy.enabled {
            let prev = self.enable_caching();
            tracing::info!("Enabling secrets caching, previously enabled {prev}");
        } else {
            let prev = self.disable_caching();
            tracing::info!("Disabling secrets caching, previously enabled {prev}");
        }

        let prev_ttl = self.set_ttl(policy.ttl);
        // `set_ttl` may store the TTL at reduced precision. Compare against the
        // value actually stored rather than the requested one. Otherwise an
        // unchanged TTL that carries extra precision would be logged as an update
        // on every call.
        let new_ttl = self.policy.ttl();
        if prev_ttl != new_ttl {
            tracing::info!(
                "Updated secrets caching TTL, new {} seconds, prev {} seconds",
                new_ttl.as_secs(),
                prev_ttl.as_secs()
            );
        }
    }

    /// Turns caching on, returning whether it was previously enabled.
    fn enable_caching(&self) -> bool {
        self.policy.set_enabled(true)
    }

    /// Turns caching off and drops every cached secret, returning whether caching
    /// was previously enabled.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    fn disable_caching(&self) -> bool {
        // Flip the flag before clearing. A reader that re-checks the flag while
        // holding the write lock then cannot re-populate the map after this clear.
        let was_enabled = self.policy.set_enabled(false);
        self.cache
            .write()
            .expect("CachingSecretsReader panicked!")
            .clear();
        was_enabled
    }

    /// Sets the TTL of cache entries, returning the previous TTL.
    fn set_ttl(&self, ttl: Duration) -> Duration {
        self.policy.set_ttl(ttl)
    }
}
#[async_trait]
impl SecretsReader for CachingSecretsReader {
    /// Returns the secret for `id`.
    ///
    /// When caching is enabled and a cached copy younger than the TTL exists, that
    /// copy is returned without touching the inner reader. Otherwise the secret is
    /// read from the inner reader and, if caching is still enabled, cached.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner reader. Failed reads are not cached.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    async fn read(&self, id: GlobalId) -> Result<Vec<u8>, anyhow::Error> {
        if self.policy.enabled() {
            let ttl = self.policy.ttl();
            // The read guard lives only inside this block, so it is released
            // before the `.await` below.
            let cache = self.cache.read().expect("CachingSecretsReader panicked!");
            if let Some(item) = cache.get(&id) {
                if item.ts.elapsed() < ttl {
                    return Ok(item.secret.clone());
                }
            }
        }

        let value = self.inner.read(id).await?;

        // Cheap pre-check so the write lock is not taken while caching is off.
        if self.policy.enabled() {
            let mut cache = self.cache.write().expect("CachingSecretsReader panicked!");
            // Re-check under the write lock. `disable_caching` flips the flag
            // before it takes this lock to clear the map. Without this check, a
            // read racing with it could re-insert a secret right after the clear,
            // leaving it cached while caching is disabled.
            if self.policy.enabled() {
                cache.insert(id, CacheItem::new(value.clone(), Instant::now()));
            }
        }
        Ok(value)
    }
}
#[cfg(test)]
mod test {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use async_trait::async_trait;
    use mz_repr::GlobalId;

    use crate::cache::CachingSecretsReader;
    use crate::{InMemorySecretsController, SecretsController, SecretsReader};

    /// A second read of the same secret is served from the cache.
    #[mz_ore::test(tokio::test)]
    async fn test_read_from_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = GlobalId::User(1);
        controller
            .ensure(id, &secret[..])
            .await
            .expect("success");

        let roundtrip = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip, secret.to_vec());
        let roundtrip2 = caching_reader.read(id).await.expect("success");
        assert_eq!(roundtrip2, secret.to_vec());

        // Only the first read reached the underlying reader.
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
        assert_eq!(reads[0], id);
    }

    /// Once an entry is older than the TTL, the next read goes to the inner reader.
    #[mz_ore::test(tokio::test)]
    async fn test_reading_expired_secret() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));
        caching_reader.set_ttl(Duration::from_secs(1));

        let secret = [42, 42, 42, 42];
        let id = GlobalId::User(1);
        controller.ensure(id, &secret).await.expect("success");

        caching_reader.read(id).await.expect("success");
        // NOTE(review): a blocking sleep inside an async test; acceptable here
        // presumably because nothing else needs to run on the test runtime meanwhile.
        std::thread::sleep(Duration::from_secs(2));
        caching_reader.read(id).await.expect("success");

        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);
        assert_eq!(reads[0], id);
        assert_eq!(reads[1], id);
    }

    /// Disabling the cache sends every read to the inner reader, and re-enabling
    /// it starts caching again from an empty cache.
    #[mz_ore::test(tokio::test)]
    async fn test_disabling_cache() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = GlobalId::User(1);
        controller.ensure(id, &secret).await.expect("success");

        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        caching_reader.disable_caching();
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 2);

        caching_reader.enable_caching();
        caching_reader.read(id).await.expect("success");
        caching_reader.read(id).await.expect("success");
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// A cached value is returned until it expires, after which the updated value
    /// is fetched and cached in turn.
    #[mz_ore::test(tokio::test)]
    async fn updating_cache_values() {
        let controller = InMemorySecretsController::new();
        let testing_reader = TestingSecretsReader::new(controller.reader());
        let caching_reader = CachingSecretsReader::new(Arc::new(testing_reader.clone()));

        let secret = [42, 42, 42, 42];
        let id = GlobalId::User(1);
        controller.ensure(id, &secret).await.expect("success");
        caching_reader.read(id).await.expect("success");

        // Update the secret underneath the cache; the stale value is still served.
        let new_secret = [100, 100];
        controller.ensure(id, &new_secret).await.expect("success");
        let cached_secret = caching_reader.read(id).await.expect("success");
        assert_eq!(cached_secret, secret);
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);

        // After the (shortened) TTL elapses, the new value is fetched once and cached.
        caching_reader.set_ttl(Duration::from_secs(2));
        std::thread::sleep(Duration::from_secs(2));
        let read1 = caching_reader.read(id).await.expect("success");
        let read2 = caching_reader.read(id).await.expect("success");
        assert_eq!(read1, new_secret);
        assert_eq!(read1, read2);
        let reads = testing_reader.drain();
        assert_eq!(reads.len(), 1);
    }

    /// A `SecretsReader` wrapper that records the id of every read it forwards.
    #[derive(Debug, Clone)]
    pub struct TestingSecretsReader {
        reader: Arc<dyn SecretsReader>,
        reads: Arc<Mutex<Vec<GlobalId>>>,
    }

    impl TestingSecretsReader {
        /// Wraps `reader` with an empty read log.
        pub fn new(reader: Arc<dyn SecretsReader>) -> Self {
            TestingSecretsReader {
                reader,
                reads: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns the recorded read ids in order and clears the log.
        pub fn drain(&self) -> Vec<GlobalId> {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .drain(..)
                .collect()
        }

        /// Appends `id` to the read log.
        fn record(&self, id: GlobalId) {
            self.reads
                .lock()
                .expect("TestingSecretsReader panicked!")
                .push(id);
        }
    }

    #[async_trait]
    impl SecretsReader for TestingSecretsReader {
        /// Forwards to the wrapped reader and records the read, whether or not it
        /// succeeded.
        async fn read(&self, id: GlobalId) -> Result<Vec<u8>, anyhow::Error> {
            let result = self.reader.read(id).await;
            self.record(id);
            result
        }
    }
}