1use std::collections::BTreeMap;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use anyhow::Context;
15use chrono::{DateTime, Utc};
16use derivative::Derivative;
17use mz_ore::cast::CastFrom;
18use mz_repr::{Datum, Diff, Row, RowArena, Timestamp};
19use mz_secrets::SecretsReader;
20use mz_secrets::cache::CachingSecretsReader;
21use mz_sql::plan::{WebhookBodyFormat, WebhookHeaders, WebhookValidation, WebhookValidationSecret};
22use mz_storage_client::controller::MonotonicAppender;
23use mz_storage_client::statistics::WebhookStatistics;
24use mz_storage_types::controller::StorageError;
25use tokio::sync::Semaphore;
26
27use crate::optimize::dataflows::{ExprPrep, ExprPrepWebhookValidation};
28
/// Errors that can occur while validating a webhook request or appending its rows.
#[derive(thiserror::Error, Debug)]
pub enum AppendWebhookError {
    /// Reading one of the secrets referenced by the validation expression failed.
    #[error("could not read a required secret")]
    MissingSecret,
    /// The request body was needed as text but is not valid UTF-8.
    #[error("the provided request body is not UTF-8: {msg}")]
    InvalidUtf8Body { msg: String },
    /// The request body is not valid JSON.
    #[error("the provided request body is not valid JSON: {msg}")]
    InvalidJsonBody { msg: String },
    /// No webhook source exists with the given database/schema/name.
    #[error("webhook source '{database}.{schema}.{name}' does not exist")]
    UnknownWebhook {
        database: String,
        schema: String,
        name: String,
    },
    /// The request did not pass validation.
    ///
    /// NOTE(review): not constructed in this file; presumably returned by callers when
    /// `AppendWebhookValidator::eval` yields `Ok(false)` — confirm.
    #[error("failed to validate the request")]
    ValidationFailed,
    /// Running the validation failed, as opposed to the request being rejected.
    ///
    /// This covers a failure to prepare or evaluate the expression, a panic during evaluation,
    /// and a failure to join the blocking validation task.
    #[error("validation error")]
    ValidationError,
    /// The appender has been invalidated; its guard reports closed.
    #[error("internal channel closed")]
    ChannelClosed,
    /// Any other internal failure, wrapped from an [`anyhow::Error`].
    #[error("internal error: {0:?}")]
    InternalError(#[from] anyhow::Error),
    /// Appending the updates to storage failed.
    #[error("internal storage failure! {0:?}")]
    StorageError(#[from] StorageError<mz_repr::Timestamp>),
}
60
/// Validates incoming webhook requests against a source's [`WebhookValidation`].
///
/// Holds the validation definition together with a secrets reader. `eval` uses the reader to
/// fetch any secrets the validation expression references.
#[derive(Clone)]
pub struct AppendWebhookValidator {
    validation: WebhookValidation,
    secrets_reader: CachingSecretsReader,
}
69
70impl AppendWebhookValidator {
71 pub fn new(validation: WebhookValidation, secrets_reader: CachingSecretsReader) -> Self {
72 AppendWebhookValidator {
73 validation,
74 secrets_reader,
75 }
76 }
77
78 pub async fn eval(
79 self,
80 body: bytes::Bytes,
81 headers: Arc<BTreeMap<String, String>>,
82 received_at: DateTime<Utc>,
83 ) -> Result<bool, AppendWebhookError> {
84 let AppendWebhookValidator {
85 validation,
86 secrets_reader,
87 } = self;
88
89 let WebhookValidation {
90 mut expression,
91 relation_desc: _,
92 secrets,
93 bodies: body_columns,
94 headers: header_columns,
95 } = validation;
96
97 let mut secret_contents = BTreeMap::new();
99 for WebhookValidationSecret {
100 id,
101 column_idx,
102 use_bytes,
103 } in secrets
104 {
105 let secret = secrets_reader
106 .read(id)
107 .await
108 .map_err(|_| AppendWebhookError::MissingSecret)?;
109 secret_contents.insert(column_idx, (secret, use_bytes));
110 }
111
112 ExprPrepWebhookValidation { now: received_at }
117 .prep_scalar_expr(&mut expression)
118 .map_err(|err| {
119 tracing::error!(?err, "failed to evaluate current time");
120 AppendWebhookError::ValidationError
121 })?;
122
123 let validate = move || {
126 let temp_storage = RowArena::default();
130 let mut datums = Vec::with_capacity(
131 body_columns.len() + header_columns.len() + secret_contents.len(),
132 );
133
134 for (column_idx, use_bytes) in body_columns {
136 assert_eq!(column_idx, datums.len(), "body index and datums mismatch!");
137
138 let datum = if use_bytes {
139 Datum::Bytes(&body[..])
140 } else {
141 let s = std::str::from_utf8(&body[..])
142 .map_err(|m| AppendWebhookError::InvalidUtf8Body { msg: m.to_string() })?;
143 Datum::String(s)
144 };
145 datums.push(datum);
146 }
147
148 let headers_byte = std::cell::OnceCell::new();
151 let headers_text = std::cell::OnceCell::new();
152 for (column_idx, use_bytes) in header_columns {
153 assert_eq!(column_idx, datums.len(), "index and datums mismatch!");
154
155 let row = if use_bytes {
156 headers_byte.get_or_init(|| {
157 let mut row = Row::with_capacity(1);
158 let mut packer = row.packer();
159 packer.push_dict(
160 headers
161 .iter()
162 .map(|(name, val)| (name.as_str(), Datum::Bytes(val.as_bytes()))),
163 );
164 row
165 })
166 } else {
167 headers_text.get_or_init(|| {
168 let mut row = Row::with_capacity(1);
169 let mut packer = row.packer();
170 packer.push_dict(
171 headers
172 .iter()
173 .map(|(name, val)| (name.as_str(), Datum::String(val))),
174 );
175 row
176 })
177 };
178 datums.push(row.unpack_first());
179 }
180
181 for column_idx in datums.len()..datums.len() + secret_contents.len() {
183 let (secret, use_bytes) = secret_contents
185 .get(&column_idx)
186 .expect("more secrets to provide, but none for the next column");
187
188 if *use_bytes {
189 datums.push(Datum::Bytes(secret));
190 } else {
191 let secret_str = std::str::from_utf8(&secret[..]).expect("valid UTF-8");
192 datums.push(Datum::String(secret_str));
193 }
194 }
195
196 let valid = expression
198 .eval(&datums[..], &temp_storage)
199 .map_err(|_| AppendWebhookError::ValidationError)?;
200 match valid {
201 Datum::True => Ok::<_, AppendWebhookError>(true),
202 Datum::False | Datum::Null => Ok(false),
203 _ => unreachable!("Creating a webhook source asserts we return a boolean"),
204 }
205 };
206
207 let valid = mz_ore::task::spawn_blocking(
209 || "webhook-validator-expr",
210 move || {
211 mz_ore::panic::catch_unwind(validate).map_err(|_| {
214 tracing::error!("panic while validating webhook request!");
215 AppendWebhookError::ValidationError
216 })
217 },
218 )
219 .await
220 .context("joining on validation")
221 .map_err(|e| {
222 tracing::error!("Failed to run validation for webhook, {e}");
223 AppendWebhookError::ValidationError
224 })?;
225
226 valid
227 }
228}
229
/// Everything needed to append a request to a webhook source: the appender, how the body and
/// headers are interpreted, and an optional validator.
#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct AppendWebhookResponse {
    /// Appender used to write rows into the source.
    pub tx: WebhookAppender,
    /// How the request body is interpreted; see [`WebhookBodyFormat`].
    pub body_format: WebhookBodyFormat,
    /// Header column configuration; see [`WebhookHeaders`].
    pub header_tys: WebhookHeaders,
    /// Optional request validator.
    ///
    /// NOTE(review): `None` presumably means requests are accepted without validation; confirm
    /// with callers.
    // Excluded from `Debug` because `AppendWebhookValidator` does not implement `Debug`.
    #[derivative(Debug = "ignore")]
    pub validator: Option<AppendWebhookValidator>,
}
243
/// Handle for appending rows to a webhook source and recording statistics about it.
///
/// The `guard` reports whether this appender has been invalidated. `stats` is shared through an
/// `Arc`, so clones update the same counters.
#[derive(Clone, Debug)]
pub struct WebhookAppender {
    tx: MonotonicAppender<Timestamp>,
    guard: WebhookAppenderGuard,
    stats: Arc<WebhookStatistics>,
}
253
254impl WebhookAppender {
255 pub fn is_closed(&self) -> bool {
257 self.guard.is_closed()
258 }
259
260 pub async fn append(&self, updates: Vec<(Row, Diff)>) -> Result<(), AppendWebhookError> {
262 if self.is_closed() {
263 return Err(AppendWebhookError::ChannelClosed);
264 }
265
266 let count = u64::cast_from(updates.len());
267 self.stats
268 .updates_staged
269 .fetch_add(count, Ordering::Relaxed);
270 let updates = updates.into_iter().map(|update| update.into()).collect();
271 self.tx.append(updates).await?;
272 self.stats
273 .updates_committed
274 .fetch_add(count, Ordering::Relaxed);
275 Ok(())
276 }
277
278 pub fn increment_messages_received(&self, msgs: u64) {
281 self.stats
282 .messages_received
283 .fetch_add(msgs, Ordering::Relaxed);
284 }
285
286 pub fn increment_bytes_received(&self, bytes: u64) {
289 self.stats
290 .bytes_received
291 .fetch_add(bytes, Ordering::Relaxed);
292 }
293
294 pub(crate) fn new(
295 tx: MonotonicAppender<Timestamp>,
296 guard: WebhookAppenderGuard,
297 stats: Arc<WebhookStatistics>,
298 ) -> Self {
299 WebhookAppender { tx, guard, stats }
300 }
301}
302
/// Cheap, cloneable, read-only view of a shared "closed" flag.
///
/// Every clone observes the same `Arc<AtomicBool>`.
#[derive(Clone, Debug)]
pub struct WebhookAppenderGuard {
    is_closed: Arc<AtomicBool>,
}

impl WebhookAppenderGuard {
    /// Reports whether the shared flag has been set, reading it with `SeqCst` ordering.
    pub fn is_closed(&self) -> bool {
        let flag: &AtomicBool = &self.is_closed;
        flag.load(Ordering::SeqCst)
    }
}
317
/// Owner of the shared "closed" flag observed by [`WebhookAppenderGuard`]s.
///
/// Guards are handed out by `guard`. Dropping the invalidator sets the flag, after which every
/// guard reports closed.
#[derive(Debug)]
pub struct WebhookAppenderInvalidator {
    is_closed: Arc<AtomicBool>,
}
// Must not be `Clone`: dropping any one copy would set the shared flag and close every guard
// while other copies are still alive.
static_assertions::assert_not_impl_all!(WebhookAppenderInvalidator: Clone);
329
330impl WebhookAppenderInvalidator {
331 pub(crate) fn new() -> WebhookAppenderInvalidator {
332 let is_closed = Arc::new(AtomicBool::new(false));
333 WebhookAppenderInvalidator { is_closed }
334 }
335
336 pub fn guard(&self) -> WebhookAppenderGuard {
337 WebhookAppenderGuard {
338 is_closed: Arc::clone(&self.is_closed),
339 }
340 }
341}
342
343impl Drop for WebhookAppenderInvalidator {
344 fn drop(&mut self) {
345 self.is_closed.store(true, Ordering::SeqCst);
346 }
347}
348
/// Key identifying a webhook source in [`WebhookAppenderCache`].
///
/// NOTE(review): presumably `(database, schema, name)`, mirroring
/// `AppendWebhookError::UnknownWebhook`; confirm against callers.
pub type WebhookAppenderName = (String, String, String);
350
351#[derive(Debug, Clone)]
356pub struct WebhookAppenderCache {
357 pub entries: Arc<tokio::sync::Mutex<BTreeMap<WebhookAppenderName, AppendWebhookResponse>>>,
358}
359
360impl WebhookAppenderCache {
361 pub fn new() -> Self {
362 WebhookAppenderCache {
363 entries: Arc::new(tokio::sync::Mutex::new(BTreeMap::new())),
364 }
365 }
366}
367
/// Bounds concurrency (presumably of in-flight webhook requests — confirm with callers) using a
/// shared [`Semaphore`] whose permit count can be changed at runtime via `set_limit`.
#[derive(Debug, Clone)]
pub struct WebhookConcurrencyLimiter {
    semaphore: Arc<Semaphore>,
    /// The limit most recently applied by `new` or `set_limit`. Used to work out how many
    /// permits to add or remove when the limit changes.
    prev_limit: usize,
}
374
375impl WebhookConcurrencyLimiter {
376 pub fn new(limit: usize) -> Self {
377 let semaphore = Arc::new(Semaphore::new(limit));
378
379 WebhookConcurrencyLimiter {
380 semaphore,
381 prev_limit: limit,
382 }
383 }
384
385 pub fn semaphore(&self) -> Arc<Semaphore> {
387 Arc::clone(&self.semaphore)
388 }
389
390 pub fn set_limit(&mut self, new_limit: usize) {
392 if new_limit > self.prev_limit {
393 let diff = new_limit.saturating_sub(self.prev_limit);
395 tracing::debug!("Adding {diff} permits");
396
397 self.semaphore.add_permits(diff);
398 } else if new_limit < self.prev_limit {
399 let diff = self.prev_limit.saturating_sub(new_limit);
401 let diff = u32::try_from(diff).unwrap_or(u32::MAX);
402 tracing::debug!("Removing {diff} permits");
403
404 let semaphore = self.semaphore();
405
406 mz_ore::task::spawn(|| "webhook-concurrency-limiter-drop-permits", async move {
409 if let Ok(permit) = Semaphore::acquire_many_owned(semaphore, diff).await {
410 permit.forget()
411 }
412 });
413 }
414
415 self.prev_limit = new_limit;
417 tracing::debug!("New limit, {} permits", self.prev_limit);
418 }
419}
420
421impl Default for WebhookConcurrencyLimiter {
422 fn default() -> Self {
423 WebhookConcurrencyLimiter::new(mz_sql::WEBHOOK_CONCURRENCY_LIMIT)
424 }
425}
426
#[cfg(test)]
mod test {
    use mz_ore::assert_err;

    use super::WebhookConcurrencyLimiter;

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn smoke_test_concurrency_limiter() {
        let mut limiter = WebhookConcurrencyLimiter::new(10);

        // Take every one of the initial permits through one handle.
        let first_handle = limiter.semaphore();
        let _all_initial = first_handle.try_acquire_many(10).expect("acquire");

        // A second handle sees the same, now exhausted, semaphore.
        let second_handle = limiter.semaphore();
        assert_err!(second_handle.try_acquire());

        // Raising the limit makes new permits available right away.
        limiter.set_limit(15);
        let _extra = second_handle.try_acquire().expect("acquire");

        // Lowering the limit removes permits from a background task, so give it time to run.
        limiter.set_limit(5);
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        assert_err!(second_handle.try_acquire());
    }
}