1use std::collections::BTreeMap;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use anyhow::Context;
15use chrono::{DateTime, Utc};
16use derivative::Derivative;
17use mz_expr::Eval;
18use mz_ore::cast::CastFrom;
19use mz_repr::{Datum, Diff, Row, RowArena};
20use mz_secrets::SecretsReader;
21use mz_secrets::cache::CachingSecretsReader;
22use mz_sql::plan::{WebhookBodyFormat, WebhookHeaders, WebhookValidation, WebhookValidationSecret};
23use mz_storage_client::controller::MonotonicAppender;
24use mz_storage_client::statistics::WebhookStatistics;
25use mz_storage_types::controller::StorageError;
26use tokio::sync::Semaphore;
27
28use crate::optimize::dataflows::{ExprPrep, ExprPrepWebhookValidation};
29
30#[derive(thiserror::Error, Debug)]
32pub enum AppendWebhookError {
33 #[error("could not read a required secret")]
35 MissingSecret,
36 #[error("the provided request body is not UTF-8: {msg}")]
37 InvalidUtf8Body { msg: String },
38 #[error("the provided request body is not valid JSON: {msg}")]
39 InvalidJsonBody { msg: String },
40 #[error("webhook source '{database}.{schema}.{name}' does not exist")]
41 UnknownWebhook {
42 database: String,
43 schema: String,
44 name: String,
45 },
46 #[error("failed to validate the request")]
47 ValidationFailed,
48 #[error("validation error")]
53 ValidationError,
54 #[error("internal channel closed")]
55 ChannelClosed,
56 #[error("internal error: {0:?}")]
57 InternalError(#[from] anyhow::Error),
58 #[error("internal storage failure! {0:?}")]
59 StorageError(#[from] StorageError),
60}
61
62#[derive(Clone)]
66pub struct AppendWebhookValidator {
67 validation: WebhookValidation,
68 secrets_reader: CachingSecretsReader,
69}
70
71impl AppendWebhookValidator {
72 pub fn new(validation: WebhookValidation, secrets_reader: CachingSecretsReader) -> Self {
73 AppendWebhookValidator {
74 validation,
75 secrets_reader,
76 }
77 }
78
79 pub async fn eval(
80 self,
81 body: bytes::Bytes,
82 headers: Arc<BTreeMap<String, String>>,
83 received_at: DateTime<Utc>,
84 ) -> Result<bool, AppendWebhookError> {
85 let AppendWebhookValidator {
86 validation,
87 secrets_reader,
88 } = self;
89
90 let WebhookValidation {
91 mut expression,
92 relation_desc: _,
93 secrets,
94 bodies: body_columns,
95 headers: header_columns,
96 } = validation;
97
98 let mut secret_contents = BTreeMap::new();
100 for WebhookValidationSecret {
101 id,
102 column_idx,
103 use_bytes,
104 } in secrets
105 {
106 let secret = secrets_reader
107 .read(id)
108 .await
109 .map_err(|_| AppendWebhookError::MissingSecret)?;
110 secret_contents.insert(column_idx, (secret, use_bytes));
111 }
112
113 ExprPrepWebhookValidation { now: received_at }
118 .prep_scalar_expr(&mut expression)
119 .map_err(|err| {
120 tracing::error!(?err, "failed to evaluate current time");
121 AppendWebhookError::ValidationError
122 })?;
123
124 let validate = move || {
127 let temp_storage = RowArena::default();
131 let mut datums = Vec::with_capacity(
132 body_columns.len() + header_columns.len() + secret_contents.len(),
133 );
134
135 for (column_idx, use_bytes) in body_columns {
137 assert_eq!(column_idx, datums.len(), "body index and datums mismatch!");
138
139 let datum = if use_bytes {
140 Datum::Bytes(&body[..])
141 } else {
142 let s = std::str::from_utf8(&body[..])
143 .map_err(|m| AppendWebhookError::InvalidUtf8Body { msg: m.to_string() })?;
144 Datum::String(s)
145 };
146 datums.push(datum);
147 }
148
149 let headers_byte = std::cell::OnceCell::new();
152 let headers_text = std::cell::OnceCell::new();
153 for (column_idx, use_bytes) in header_columns {
154 assert_eq!(column_idx, datums.len(), "index and datums mismatch!");
155
156 let row = if use_bytes {
157 headers_byte.get_or_init(|| {
158 let mut row = Row::with_capacity(1);
159 let mut packer = row.packer();
160 packer.push_dict(
161 headers
162 .iter()
163 .map(|(name, val)| (name.as_str(), Datum::Bytes(val.as_bytes()))),
164 );
165 row
166 })
167 } else {
168 headers_text.get_or_init(|| {
169 let mut row = Row::with_capacity(1);
170 let mut packer = row.packer();
171 packer.push_dict(
172 headers
173 .iter()
174 .map(|(name, val)| (name.as_str(), Datum::String(val))),
175 );
176 row
177 })
178 };
179 datums.push(row.unpack_first());
180 }
181
182 for column_idx in datums.len()..datums.len() + secret_contents.len() {
184 let (secret, use_bytes) = secret_contents
186 .get(&column_idx)
187 .expect("more secrets to provide, but none for the next column");
188
189 if *use_bytes {
190 datums.push(Datum::Bytes(secret));
191 } else {
192 let secret_str = std::str::from_utf8(&secret[..]).expect("valid UTF-8");
193 datums.push(Datum::String(secret_str));
194 }
195 }
196
197 let valid = expression
199 .eval(&datums[..], &temp_storage)
200 .map_err(|_| AppendWebhookError::ValidationError)?;
201 match valid {
202 Datum::True => Ok::<_, AppendWebhookError>(true),
203 Datum::False | Datum::Null => Ok(false),
204 _ => unreachable!("Creating a webhook source asserts we return a boolean"),
205 }
206 };
207
208 let valid = mz_ore::task::spawn_blocking(
210 || "webhook-validator-expr",
211 move || {
212 mz_ore::panic::catch_unwind(validate).map_err(|_| {
215 tracing::error!("panic while validating webhook request!");
216 AppendWebhookError::ValidationError
217 })
218 },
219 )
220 .await
221 .context("joining on validation")
222 .map_err(|e| {
223 tracing::error!("Failed to run validation for webhook, {e}");
224 AppendWebhookError::ValidationError
225 })?;
226
227 valid
228 }
229}
230
231#[derive(Derivative, Clone)]
232#[derivative(Debug)]
233pub struct AppendWebhookResponse {
234 pub tx: WebhookAppender,
236 pub body_format: WebhookBodyFormat,
238 pub header_tys: WebhookHeaders,
240 #[derivative(Debug = "ignore")]
242 pub validator: Option<AppendWebhookValidator>,
243}
244
245#[derive(Clone, Debug)]
248pub struct WebhookAppender {
249 tx: MonotonicAppender,
250 guard: WebhookAppenderGuard,
251 stats: Arc<WebhookStatistics>,
253}
254
255impl WebhookAppender {
256 pub fn is_closed(&self) -> bool {
258 self.guard.is_closed()
259 }
260
261 pub async fn append(&self, updates: Vec<(Row, Diff)>) -> Result<(), AppendWebhookError> {
263 if self.is_closed() {
264 return Err(AppendWebhookError::ChannelClosed);
265 }
266
267 let count = u64::cast_from(updates.len());
268 self.stats
269 .updates_staged
270 .fetch_add(count, Ordering::Relaxed);
271 let updates = updates.into_iter().map(|update| update.into()).collect();
272 self.tx.append(updates).await?;
273 self.stats
274 .updates_committed
275 .fetch_add(count, Ordering::Relaxed);
276 Ok(())
277 }
278
279 pub fn increment_messages_received(&self, msgs: u64) {
282 self.stats
283 .messages_received
284 .fetch_add(msgs, Ordering::Relaxed);
285 }
286
287 pub fn increment_bytes_received(&self, bytes: u64) {
290 self.stats
291 .bytes_received
292 .fetch_add(bytes, Ordering::Relaxed);
293 }
294
295 pub(crate) fn new(
296 tx: MonotonicAppender,
297 guard: WebhookAppenderGuard,
298 stats: Arc<WebhookStatistics>,
299 ) -> Self {
300 WebhookAppender { tx, guard, stats }
301 }
302}
303
/// A cloneable, read-only view of a shared "closed" flag.
///
/// All clones observe the same flag because they share the same `Arc<AtomicBool>`.
#[derive(Debug, Clone)]
pub struct WebhookAppenderGuard {
    /// Shared flag; `true` means closed.
    is_closed: Arc<AtomicBool>,
}

impl WebhookAppenderGuard {
    /// Returns the current value of the shared flag.
    pub fn is_closed(&self) -> bool {
        let flag: &AtomicBool = &self.is_closed;
        flag.load(Ordering::SeqCst)
    }
}
318
319#[derive(Debug)]
325pub struct WebhookAppenderInvalidator {
326 is_closed: Arc<AtomicBool>,
327}
328static_assertions::assert_not_impl_all!(WebhookAppenderInvalidator: Clone);
330
331impl WebhookAppenderInvalidator {
332 pub(crate) fn new() -> WebhookAppenderInvalidator {
333 let is_closed = Arc::new(AtomicBool::new(false));
334 WebhookAppenderInvalidator { is_closed }
335 }
336
337 pub fn guard(&self) -> WebhookAppenderGuard {
338 WebhookAppenderGuard {
339 is_closed: Arc::clone(&self.is_closed),
340 }
341 }
342}
343
344impl Drop for WebhookAppenderInvalidator {
345 fn drop(&mut self) {
346 self.is_closed.store(true, Ordering::SeqCst);
347 }
348}
349
/// Key type for [`WebhookAppenderCache`]: a tuple of three strings.
///
/// NOTE(review): presumably `(database, schema, name)`, matching
/// `AppendWebhookError::UnknownWebhook` -- confirm against callers.
pub type WebhookAppenderName = (String, String, String);
351
352#[derive(Debug, Clone)]
357pub struct WebhookAppenderCache {
358 pub entries: Arc<tokio::sync::Mutex<BTreeMap<WebhookAppenderName, AppendWebhookResponse>>>,
359}
360
361impl WebhookAppenderCache {
362 pub fn new() -> Self {
363 WebhookAppenderCache {
364 entries: Arc::new(tokio::sync::Mutex::new(BTreeMap::new())),
365 }
366 }
367}
368
369#[derive(Debug, Clone)]
371pub struct WebhookConcurrencyLimiter {
372 semaphore: Arc<Semaphore>,
373 prev_limit: usize,
374}
375
376impl WebhookConcurrencyLimiter {
377 pub fn new(limit: usize) -> Self {
378 let semaphore = Arc::new(Semaphore::new(limit));
379
380 WebhookConcurrencyLimiter {
381 semaphore,
382 prev_limit: limit,
383 }
384 }
385
386 pub fn semaphore(&self) -> Arc<Semaphore> {
388 Arc::clone(&self.semaphore)
389 }
390
391 pub fn set_limit(&mut self, new_limit: usize) {
393 if new_limit > self.prev_limit {
394 let diff = new_limit.saturating_sub(self.prev_limit);
396 tracing::debug!("Adding {diff} permits");
397
398 self.semaphore.add_permits(diff);
399 } else if new_limit < self.prev_limit {
400 let diff = self.prev_limit.saturating_sub(new_limit);
402 let diff = u32::try_from(diff).unwrap_or(u32::MAX);
403 tracing::debug!("Removing {diff} permits");
404
405 let semaphore = self.semaphore();
406
407 mz_ore::task::spawn(|| "webhook-concurrency-limiter-drop-permits", async move {
410 if let Ok(permit) = Semaphore::acquire_many_owned(semaphore, diff).await {
411 permit.forget()
412 }
413 });
414 }
415
416 self.prev_limit = new_limit;
418 tracing::debug!("New limit, {} permits", self.prev_limit);
419 }
420}
421
422impl Default for WebhookConcurrencyLimiter {
423 fn default() -> Self {
424 WebhookConcurrencyLimiter::new(mz_sql::WEBHOOK_CONCURRENCY_LIMIT)
425 }
426}
427
#[cfg(test)]
mod test {
    use mz_ore::assert_err;

    use super::WebhookConcurrencyLimiter;

    /// Exercises raising and then lowering the limit while permits are held.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn smoke_test_concurrency_limiter() {
        let mut limiter = WebhookConcurrencyLimiter::new(10);

        // Take every permit; the binding keeps them held for the rest of the test.
        let first_handle = limiter.semaphore();
        let _all_initial_permits = first_handle.try_acquire_many(10).expect("acquire");

        // With nothing left, a second handle cannot acquire.
        let second_handle = limiter.semaphore();
        assert_err!(second_handle.try_acquire());

        // Raising the limit makes new permits available right away.
        limiter.set_limit(15);

        let _one_extra_permit = second_handle.try_acquire().expect("acquire");

        // Lowering the limit removes permits in a spawned task.
        limiter.set_limit(5);

        // Give the spawned task time to run.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // The remaining spare permits should now be gone.
        assert_err!(second_handle.try_acquire());
    }
}