//! Types and helpers for validating webhook requests and appending them to webhook sources.
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use anyhow::Context;
use chrono::{DateTime, Utc};
use derivative::Derivative;
use mz_ore::cast::CastFrom;
use mz_repr::{Datum, Diff, Row, RowArena, Timestamp};
use mz_secrets::cache::CachingSecretsReader;
use mz_secrets::SecretsReader;
use mz_sql::plan::{WebhookBodyFormat, WebhookHeaders, WebhookValidation, WebhookValidationSecret};
use mz_storage_client::controller::MonotonicAppender;
use mz_storage_client::statistics::WebhookStatistics;
use mz_storage_types::controller::StorageError;
use tokio::sync::Semaphore;
use crate::optimize::dataflows::{prep_scalar_expr, ExprPrepStyle};
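/// Errors that can occur while validating or appending a webhook request.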
#[derive(thiserror::Error, Debug)]
pub enum AppendWebhookError {
#[error("could not read a required secret")]
MissingSecret,
#[error("the provided request body is not UTF-8: {msg}")]
InvalidUtf8Body { msg: String },
#[error("the provided request body is not valid JSON: {msg}")]
InvalidJsonBody { msg: String },
#[error("webhook source '{database}.{schema}.{name}' does not exist")]
UnknownWebhook {
database: String,
schema: String,
name: String,
},
#[error("failed to validate the request")]
ValidationFailed,
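// Note: deliberately carries no detail about what went wrong during validation, so that
// nothing about the validation expression or its secrets leaks to the client.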
#[error("validation error")]
ValidationError,
#[error("internal channel closed")]
ChannelClosed,
#[error("internal error: {0:?}")]
InternalError(#[from] anyhow::Error),
#[error("internal storage failure! {0:?}")]
StorageError(#[from] StorageError<mz_repr::Timestamp>),
}
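/// Bundles a webhook source's validation expression together with a secrets reader so
/// the expression can be evaluated against incoming requests.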
#[derive(Clone)]
pub struct AppendWebhookValidator {
validation: WebhookValidation,
secrets_reader: CachingSecretsReader,
}
impl AppendWebhookValidator {
pub fn new(validation: WebhookValidation, secrets_reader: CachingSecretsReader) -> Self {
AppendWebhookValidator {
validation,
secrets_reader,
}
}
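/// Evaluates the validation expression against a request's `body`, `headers`, and
/// `received_at` timestamp, returning whether the request passed validation.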
pub async fn eval(
self,
body: bytes::Bytes,
headers: Arc<BTreeMap<String, String>>,
received_at: DateTime<Utc>,
) -> Result<bool, AppendWebhookError> {
let AppendWebhookValidator {
validation,
secrets_reader,
} = self;
let WebhookValidation {
mut expression,
relation_desc: _,
secrets,
bodies: body_columns,
headers: header_columns,
} = validation;
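// Read the contents of every secret referenced by the validation expression.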
let mut secret_contents = BTreeMap::new();
for WebhookValidationSecret {
id,
column_idx,
use_bytes,
} in secrets
{
let secret = secrets_reader
.read(id)
.await
.map_err(|_| AppendWebhookError::MissingSecret)?;
secret_contents.insert(column_idx, (secret, use_bytes));
}
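// Resolve the expression's notion of "now" to the time the request was received, so
// evaluation below needs no access to a clock.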
prep_scalar_expr(
&mut expression,
ExprPrepStyle::WebhookValidation { now: received_at },
)
.map_err(|err| {
tracing::error!(?err, "failed to evaluate current time");
AppendWebhookError::ValidationError
})?;
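// Assemble the datums the expression is evaluated against. Column order must match
// the webhook source's relation: body columns, then header columns, then secrets.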
let validate = move || {
let temp_storage = RowArena::default();
let mut datums = Vec::with_capacity(
body_columns.len() + header_columns.len() + secret_contents.len(),
);
for (column_idx, use_bytes) in body_columns {
assert_eq!(column_idx, datums.len(), "body index and datums mismatch!");
let datum = if use_bytes {
Datum::Bytes(&body[..])
} else {
let s = std::str::from_utf8(&body[..])
.map_err(|m| AppendWebhookError::InvalidUtf8Body { msg: m.to_string() })?;
Datum::String(s)
};
datums.push(datum);
}
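// Lazily pack the headers into a map Datum, in both bytes and text representations,
// so the packing happens at most once per representation.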
let headers_byte = std::cell::OnceCell::new();
let headers_text = std::cell::OnceCell::new();
for (column_idx, use_bytes) in header_columns {
assert_eq!(column_idx, datums.len(), "index and datums mismatch!");
let row = if use_bytes {
headers_byte.get_or_init(|| {
let mut row = Row::with_capacity(1);
let mut packer = row.packer();
packer.push_dict(
headers
.iter()
.map(|(name, val)| (name.as_str(), Datum::Bytes(val.as_bytes()))),
);
row
})
} else {
headers_text.get_or_init(|| {
let mut row = Row::with_capacity(1);
let mut packer = row.packer();
packer.push_dict(
headers
.iter()
.map(|(name, val)| (name.as_str(), Datum::String(val))),
);
row
})
};
datums.push(row.unpack_first());
}
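// Secrets occupy the remaining columns and are looked up by column index.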
for column_idx in datums.len()..datums.len() + secret_contents.len() {
let (secret, use_bytes) = secret_contents
.get(&column_idx)
.expect("more secrets to provide, but none for the next column");
if *use_bytes {
datums.push(Datum::Bytes(secret));
} else {
let secret_str = std::str::from_utf8(&secret[..]).expect("valid UTF-8");
datums.push(Datum::String(secret_str));
}
}
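// Evaluate the expression; creating the webhook source guarantees it returns a boolean.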
let valid = expression
.eval(&datums[..], &temp_storage)
.map_err(|_| AppendWebhookError::ValidationError)?;
match valid {
Datum::True => Ok::<_, AppendWebhookError>(true),
Datum::False | Datum::Null => Ok(false),
_ => unreachable!("Creating a webhook source asserts we return a boolean"),
}
};
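// Expression evaluation can be CPU intensive, so run it on a blocking thread and turn
// any panic into a generic validation error.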
let valid = mz_ore::task::spawn_blocking(
|| "webhook-validator-expr",
move || {
mz_ore::panic::catch_unwind(validate).map_err(|_| {
tracing::error!("panic while validating webhook request!");
AppendWebhookError::ValidationError
})
},
)
.await
.context("joining on validation")
.map_err(|e| {
tracing::error!("Failed to run validation for webhook, {e}");
AppendWebhookError::ValidationError
})??;
valid
}
}
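/// Everything needed to serve requests for a single webhook source: an appender for
/// writing rows, the expected body format, the header handling, and an optional validator.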
#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct AppendWebhookResponse {
pub tx: WebhookAppender,
pub body_format: WebhookBodyFormat,
pub header_tys: WebhookHeaders,
#[derivative(Debug = "ignore")]
pub validator: Option<AppendWebhookValidator>,
}
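/// Handle for appending rows to a webhook source that also records statistics and tracks
/// whether the source has been invalidated.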
#[derive(Clone, Debug)]
pub struct WebhookAppender {
tx: MonotonicAppender<Timestamp>,
guard: WebhookAppenderGuard,
stats: Arc<WebhookStatistics>,
}
impl WebhookAppender {
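/// Reports whether this appender has been invalidated, e.g. because the underlying
/// webhook source went away.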
pub fn is_closed(&self) -> bool {
self.guard.is_closed()
}
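/// Appends `updates` to the webhook source, recording staged and committed update counts.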
pub async fn append(&self, updates: Vec<(Row, Diff)>) -> Result<(), AppendWebhookError> {
if self.is_closed() {
return Err(AppendWebhookError::ChannelClosed);
}
let count = u64::cast_from(updates.len());
self.stats
.updates_staged
.fetch_add(count, Ordering::Relaxed);
let updates = updates.into_iter().map(|update| update.into()).collect();
self.tx.append(updates).await?;
self.stats
.updates_committed
.fetch_add(count, Ordering::Relaxed);
Ok(())
}
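/// Records that `msgs` messages were received for this webhook source.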
pub fn increment_messages_received(&self, msgs: u64) {
self.stats
.messages_received
.fetch_add(msgs, Ordering::Relaxed);
}
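/// Records that `bytes` bytes were received for this webhook source.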
pub fn increment_bytes_received(&self, bytes: u64) {
self.stats
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
}
pub(crate) fn new(
tx: MonotonicAppender<Timestamp>,
guard: WebhookAppenderGuard,
stats: Arc<WebhookStatistics>,
) -> Self {
WebhookAppender { tx, guard, stats }
}
}
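/// Shared flag reporting whether the appenders it guards have been invalidated. Created
/// via [`WebhookAppenderInvalidator::guard`].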
#[derive(Clone, Debug)]
pub struct WebhookAppenderGuard {
is_closed: Arc<AtomicBool>,
}
impl WebhookAppenderGuard {
pub fn is_closed(&self) -> bool {
self.is_closed.load(Ordering::SeqCst)
}
}
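/// Invalidates every [`WebhookAppenderGuard`] created from it when dropped. Deliberately
/// not `Clone`, so exactly one owner is responsible for invalidation.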
#[derive(Debug)]
pub struct WebhookAppenderInvalidator {
is_closed: Arc<AtomicBool>,
}
static_assertions::assert_not_impl_all!(WebhookAppenderInvalidator: Clone);
impl WebhookAppenderInvalidator {
pub(crate) fn new() -> WebhookAppenderInvalidator {
let is_closed = Arc::new(AtomicBool::new(false));
WebhookAppenderInvalidator { is_closed }
}
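/// Returns a new [`WebhookAppenderGuard`] tied to this invalidator.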
pub fn guard(&self) -> WebhookAppenderGuard {
WebhookAppenderGuard {
is_closed: Arc::clone(&self.is_closed),
}
}
}
impl Drop for WebhookAppenderInvalidator {
fn drop(&mut self) {
self.is_closed.store(true, Ordering::SeqCst);
}
}
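/// Fully qualified name of a webhook source: `(database, schema, object name)`.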
pub type WebhookAppenderName = (String, String, String);
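/// Cache of [`AppendWebhookResponse`]s keyed by the webhook source's fully qualified name.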
#[derive(Debug, Clone)]
pub struct WebhookAppenderCache {
pub entries: Arc<tokio::sync::Mutex<BTreeMap<WebhookAppenderName, AppendWebhookResponse>>>,
}
impl WebhookAppenderCache {
pub fn new() -> Self {
WebhookAppenderCache {
entries: Arc::new(tokio::sync::Mutex::new(BTreeMap::new())),
}
}
}
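/// Limits the number of concurrent webhook requests with a [`Semaphore`] whose permit
/// count can be resized at runtime via [`WebhookConcurrencyLimiter::set_limit`].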
#[derive(Debug, Clone)]
pub struct WebhookConcurrencyLimiter {
semaphore: Arc<Semaphore>,
prev_limit: usize,
}
impl WebhookConcurrencyLimiter {
pub fn new(limit: usize) -> Self {
let semaphore = Arc::new(Semaphore::new(limit));
WebhookConcurrencyLimiter {
semaphore,
prev_limit: limit,
}
}
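/// Returns a handle to the underlying semaphore; callers acquire permits from it to
/// bound the number of concurrent requests.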
pub fn semaphore(&self) -> Arc<Semaphore> {
Arc::clone(&self.semaphore)
}
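/// Updates the concurrency limit. Raising the limit adds permits immediately; lowering it
/// spawns a task that acquires and forgets the excess permits, so the reduction takes
/// effect only as outstanding permits are released.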
pub fn set_limit(&mut self, new_limit: usize) {
if new_limit > self.prev_limit {
let diff = new_limit.saturating_sub(self.prev_limit);
tracing::debug!("Adding {diff} permits");
self.semaphore.add_permits(diff);
} else if new_limit < self.prev_limit {
let diff = self.prev_limit.saturating_sub(new_limit);
let diff = u32::try_from(diff).unwrap_or(u32::MAX);
tracing::debug!("Removing {diff} permits");
let semaphore = self.semaphore();
mz_ore::task::spawn(|| "webhook-concurrency-limiter-drop-permits", async move {
if let Ok(permit) = Semaphore::acquire_many_owned(semaphore, diff).await {
permit.forget()
}
});
}
self.prev_limit = new_limit;
tracing::debug!("New limit, {} permits", self.prev_limit);
}
}
impl Default for WebhookConcurrencyLimiter {
fn default() -> Self {
WebhookConcurrencyLimiter::new(mz_sql::WEBHOOK_CONCURRENCY_LIMIT)
}
}
#[cfg(test)]
mod test {
use mz_ore::assert_err;
use super::WebhookConcurrencyLimiter;
#[mz_ore::test(tokio::test)]
#[cfg_attr(miri, ignore)]
async fn smoke_test_concurrency_limiter() {
let mut limiter = WebhookConcurrencyLimiter::new(10);
let semaphore_a = limiter.semaphore();
let _permit_a = semaphore_a.try_acquire_many(10).expect("acquire");
let semaphore_b = limiter.semaphore();
assert_err!(semaphore_b.try_acquire());
limiter.set_limit(15);
let _permit_b = semaphore_b.try_acquire().expect("acquire");
limiter.set_limit(5);
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
assert_err!(semaphore_b.try_acquire());
}
}