1use std::collections::BTreeMap;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use anyhow::Context;
15use chrono::{DateTime, Utc};
16use derivative::Derivative;
17use mz_ore::cast::CastFrom;
18use mz_repr::{Datum, Diff, Row, RowArena, Timestamp};
19use mz_secrets::SecretsReader;
20use mz_secrets::cache::CachingSecretsReader;
21use mz_sql::plan::{WebhookBodyFormat, WebhookHeaders, WebhookValidation, WebhookValidationSecret};
22use mz_storage_client::controller::MonotonicAppender;
23use mz_storage_client::statistics::WebhookStatistics;
24use mz_storage_types::controller::StorageError;
25use tokio::sync::Semaphore;
26
27use crate::optimize::dataflows::{ExprPrepStyle, prep_scalar_expr};
28
29#[derive(thiserror::Error, Debug)]
31pub enum AppendWebhookError {
32 #[error("could not read a required secret")]
34 MissingSecret,
35 #[error("the provided request body is not UTF-8: {msg}")]
36 InvalidUtf8Body { msg: String },
37 #[error("the provided request body is not valid JSON: {msg}")]
38 InvalidJsonBody { msg: String },
39 #[error("webhook source '{database}.{schema}.{name}' does not exist")]
40 UnknownWebhook {
41 database: String,
42 schema: String,
43 name: String,
44 },
45 #[error("failed to validate the request")]
46 ValidationFailed,
47 #[error("validation error")]
52 ValidationError,
53 #[error("internal channel closed")]
54 ChannelClosed,
55 #[error("internal error: {0:?}")]
56 InternalError(#[from] anyhow::Error),
57 #[error("internal storage failure! {0:?}")]
58 StorageError(#[from] StorageError<mz_repr::Timestamp>),
59}
60
61#[derive(Clone)]
65pub struct AppendWebhookValidator {
66 validation: WebhookValidation,
67 secrets_reader: CachingSecretsReader,
68}
69
70impl AppendWebhookValidator {
71 pub fn new(validation: WebhookValidation, secrets_reader: CachingSecretsReader) -> Self {
72 AppendWebhookValidator {
73 validation,
74 secrets_reader,
75 }
76 }
77
78 pub async fn eval(
79 self,
80 body: bytes::Bytes,
81 headers: Arc<BTreeMap<String, String>>,
82 received_at: DateTime<Utc>,
83 ) -> Result<bool, AppendWebhookError> {
84 let AppendWebhookValidator {
85 validation,
86 secrets_reader,
87 } = self;
88
89 let WebhookValidation {
90 mut expression,
91 relation_desc: _,
92 secrets,
93 bodies: body_columns,
94 headers: header_columns,
95 } = validation;
96
97 let mut secret_contents = BTreeMap::new();
99 for WebhookValidationSecret {
100 id,
101 column_idx,
102 use_bytes,
103 } in secrets
104 {
105 let secret = secrets_reader
106 .read(id)
107 .await
108 .map_err(|_| AppendWebhookError::MissingSecret)?;
109 secret_contents.insert(column_idx, (secret, use_bytes));
110 }
111
112 prep_scalar_expr(
117 &mut expression,
118 ExprPrepStyle::WebhookValidation { now: received_at },
119 )
120 .map_err(|err| {
121 tracing::error!(?err, "failed to evaluate current time");
122 AppendWebhookError::ValidationError
123 })?;
124
125 let validate = move || {
128 let temp_storage = RowArena::default();
132 let mut datums = Vec::with_capacity(
133 body_columns.len() + header_columns.len() + secret_contents.len(),
134 );
135
136 for (column_idx, use_bytes) in body_columns {
138 assert_eq!(column_idx, datums.len(), "body index and datums mismatch!");
139
140 let datum = if use_bytes {
141 Datum::Bytes(&body[..])
142 } else {
143 let s = std::str::from_utf8(&body[..])
144 .map_err(|m| AppendWebhookError::InvalidUtf8Body { msg: m.to_string() })?;
145 Datum::String(s)
146 };
147 datums.push(datum);
148 }
149
150 let headers_byte = std::cell::OnceCell::new();
153 let headers_text = std::cell::OnceCell::new();
154 for (column_idx, use_bytes) in header_columns {
155 assert_eq!(column_idx, datums.len(), "index and datums mismatch!");
156
157 let row = if use_bytes {
158 headers_byte.get_or_init(|| {
159 let mut row = Row::with_capacity(1);
160 let mut packer = row.packer();
161 packer.push_dict(
162 headers
163 .iter()
164 .map(|(name, val)| (name.as_str(), Datum::Bytes(val.as_bytes()))),
165 );
166 row
167 })
168 } else {
169 headers_text.get_or_init(|| {
170 let mut row = Row::with_capacity(1);
171 let mut packer = row.packer();
172 packer.push_dict(
173 headers
174 .iter()
175 .map(|(name, val)| (name.as_str(), Datum::String(val))),
176 );
177 row
178 })
179 };
180 datums.push(row.unpack_first());
181 }
182
183 for column_idx in datums.len()..datums.len() + secret_contents.len() {
185 let (secret, use_bytes) = secret_contents
187 .get(&column_idx)
188 .expect("more secrets to provide, but none for the next column");
189
190 if *use_bytes {
191 datums.push(Datum::Bytes(secret));
192 } else {
193 let secret_str = std::str::from_utf8(&secret[..]).expect("valid UTF-8");
194 datums.push(Datum::String(secret_str));
195 }
196 }
197
198 let valid = expression
200 .eval(&datums[..], &temp_storage)
201 .map_err(|_| AppendWebhookError::ValidationError)?;
202 match valid {
203 Datum::True => Ok::<_, AppendWebhookError>(true),
204 Datum::False | Datum::Null => Ok(false),
205 _ => unreachable!("Creating a webhook source asserts we return a boolean"),
206 }
207 };
208
209 let valid = mz_ore::task::spawn_blocking(
211 || "webhook-validator-expr",
212 move || {
213 mz_ore::panic::catch_unwind(validate).map_err(|_| {
216 tracing::error!("panic while validating webhook request!");
217 AppendWebhookError::ValidationError
218 })
219 },
220 )
221 .await
222 .context("joining on validation")
223 .map_err(|e| {
224 tracing::error!("Failed to run validation for webhook, {e}");
225 AppendWebhookError::ValidationError
226 })??;
227
228 valid
229 }
230}
231
232#[derive(Derivative, Clone)]
233#[derivative(Debug)]
234pub struct AppendWebhookResponse {
235 pub tx: WebhookAppender,
237 pub body_format: WebhookBodyFormat,
239 pub header_tys: WebhookHeaders,
241 #[derivative(Debug = "ignore")]
243 pub validator: Option<AppendWebhookValidator>,
244}
245
246#[derive(Clone, Debug)]
249pub struct WebhookAppender {
250 tx: MonotonicAppender<Timestamp>,
251 guard: WebhookAppenderGuard,
252 stats: Arc<WebhookStatistics>,
254}
255
256impl WebhookAppender {
257 pub fn is_closed(&self) -> bool {
259 self.guard.is_closed()
260 }
261
262 pub async fn append(&self, updates: Vec<(Row, Diff)>) -> Result<(), AppendWebhookError> {
264 if self.is_closed() {
265 return Err(AppendWebhookError::ChannelClosed);
266 }
267
268 let count = u64::cast_from(updates.len());
269 self.stats
270 .updates_staged
271 .fetch_add(count, Ordering::Relaxed);
272 let updates = updates.into_iter().map(|update| update.into()).collect();
273 self.tx.append(updates).await?;
274 self.stats
275 .updates_committed
276 .fetch_add(count, Ordering::Relaxed);
277 Ok(())
278 }
279
280 pub fn increment_messages_received(&self, msgs: u64) {
283 self.stats
284 .messages_received
285 .fetch_add(msgs, Ordering::Relaxed);
286 }
287
288 pub fn increment_bytes_received(&self, bytes: u64) {
291 self.stats
292 .bytes_received
293 .fetch_add(bytes, Ordering::Relaxed);
294 }
295
296 pub(crate) fn new(
297 tx: MonotonicAppender<Timestamp>,
298 guard: WebhookAppenderGuard,
299 stats: Arc<WebhookStatistics>,
300 ) -> Self {
301 WebhookAppender { tx, guard, stats }
302 }
303}
304
/// A cloneable, read-only view of a shared "closed" flag.
#[derive(Debug, Clone)]
pub struct WebhookAppenderGuard {
    /// Flag shared with whoever created this guard; `true` means closed.
    is_closed: Arc<AtomicBool>,
}

impl WebhookAppenderGuard {
    /// Returns the current value of the shared flag.
    pub fn is_closed(&self) -> bool {
        let flag: &AtomicBool = &self.is_closed;
        flag.load(Ordering::SeqCst)
    }
}
319
320#[derive(Debug)]
326pub struct WebhookAppenderInvalidator {
327 is_closed: Arc<AtomicBool>,
328}
329static_assertions::assert_not_impl_all!(WebhookAppenderInvalidator: Clone);
331
332impl WebhookAppenderInvalidator {
333 pub(crate) fn new() -> WebhookAppenderInvalidator {
334 let is_closed = Arc::new(AtomicBool::new(false));
335 WebhookAppenderInvalidator { is_closed }
336 }
337
338 pub fn guard(&self) -> WebhookAppenderGuard {
339 WebhookAppenderGuard {
340 is_closed: Arc::clone(&self.is_closed),
341 }
342 }
343}
344
345impl Drop for WebhookAppenderInvalidator {
346 fn drop(&mut self) {
347 self.is_closed.store(true, Ordering::SeqCst);
348 }
349}
350
/// Key identifying a webhook appender in [`WebhookAppenderCache`].
///
/// NOTE(review): presumably `(database, schema, name)`; confirm against the callers.
pub type WebhookAppenderName = (String, String, String);
352
353#[derive(Debug, Clone)]
358pub struct WebhookAppenderCache {
359 pub entries: Arc<tokio::sync::Mutex<BTreeMap<WebhookAppenderName, AppendWebhookResponse>>>,
360}
361
362impl WebhookAppenderCache {
363 pub fn new() -> Self {
364 WebhookAppenderCache {
365 entries: Arc::new(tokio::sync::Mutex::new(BTreeMap::new())),
366 }
367 }
368}
369
370#[derive(Debug, Clone)]
372pub struct WebhookConcurrencyLimiter {
373 semaphore: Arc<Semaphore>,
374 prev_limit: usize,
375}
376
377impl WebhookConcurrencyLimiter {
378 pub fn new(limit: usize) -> Self {
379 let semaphore = Arc::new(Semaphore::new(limit));
380
381 WebhookConcurrencyLimiter {
382 semaphore,
383 prev_limit: limit,
384 }
385 }
386
387 pub fn semaphore(&self) -> Arc<Semaphore> {
389 Arc::clone(&self.semaphore)
390 }
391
392 pub fn set_limit(&mut self, new_limit: usize) {
394 if new_limit > self.prev_limit {
395 let diff = new_limit.saturating_sub(self.prev_limit);
397 tracing::debug!("Adding {diff} permits");
398
399 self.semaphore.add_permits(diff);
400 } else if new_limit < self.prev_limit {
401 let diff = self.prev_limit.saturating_sub(new_limit);
403 let diff = u32::try_from(diff).unwrap_or(u32::MAX);
404 tracing::debug!("Removing {diff} permits");
405
406 let semaphore = self.semaphore();
407
408 mz_ore::task::spawn(|| "webhook-concurrency-limiter-drop-permits", async move {
411 if let Ok(permit) = Semaphore::acquire_many_owned(semaphore, diff).await {
412 permit.forget()
413 }
414 });
415 }
416
417 self.prev_limit = new_limit;
419 tracing::debug!("New limit, {} permits", self.prev_limit);
420 }
421}
422
423impl Default for WebhookConcurrencyLimiter {
424 fn default() -> Self {
425 WebhookConcurrencyLimiter::new(mz_sql::WEBHOOK_CONCURRENCY_LIMIT)
426 }
427}
428
#[cfg(test)]
mod test {
    use mz_ore::assert_err;

    use super::WebhookConcurrencyLimiter;

    /// Exercises raising and then lowering the limit on a limiter that starts with 10 permits.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn smoke_test_concurrency_limiter() {
        let mut limiter = WebhookConcurrencyLimiter::new(10);

        // Take every permit through one handle.
        let first_handle = limiter.semaphore();
        let _all_permits = first_handle.try_acquire_many(10).expect("acquire");

        // A second handle shares the same semaphore, so nothing is left to acquire.
        let second_handle = limiter.semaphore();
        assert_err!(second_handle.try_acquire());

        // Raising the limit makes more permits available right away.
        limiter.set_limit(15);

        let _extra_permit = second_handle.try_acquire().expect("acquire");

        // Lowering the limit removes permits on a spawned task.
        limiter.set_limit(5);

        // Give the spawned task time to run.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // No permit can be acquired any more.
        assert_err!(second_handle.try_acquire());
    }
}