1#![allow(clippy::as_conversions, clippy::clone_on_ref_ptr)]
13
14use std::collections::{BTreeMap, BTreeSet};
17use std::mem;
18
19use async_trait::async_trait;
20use bytesize::ByteSize;
21use differential_dataflow::consolidation::consolidate_updates;
22use differential_dataflow::lattice::Lattice;
23use mz_expr::row::RowCollection;
24use mz_ore::cast::CastInto;
25use mz_ore::soft_panic_or_log;
26use mz_ore::tracing::OpenTelemetryContext;
27use mz_repr::{Diff, GlobalId, Row};
28use mz_service::client::{GenericClient, Partitionable, PartitionedState};
29use timely::PartialOrder;
30use timely::progress::frontier::{Antichain, MutableAntichain};
31use uuid::Uuid;
32
33use crate::controller::ComputeControllerTimestamp;
34use crate::protocol::command::ComputeCommand;
35use crate::protocol::response::{
36 ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, StashedPeekResponse,
37 SubscribeBatch, SubscribeResponse,
38};
39
/// A client that speaks the compute protocol: it sends [`ComputeCommand`]s and receives
/// [`ComputeResponse`]s.
///
/// The trait adds no methods of its own. It exists so that any
/// `GenericClient<ComputeCommand<T>, ComputeResponse<T>>` can be named and used as a trait
/// object (`Box<dyn ComputeClient<T>>`).
pub trait ComputeClient<T = mz_repr::Timestamp>:
    GenericClient<ComputeCommand<T>, ComputeResponse<T>>
{
}
45
46impl<C, T> ComputeClient<T> for C where C: GenericClient<ComputeCommand<T>, ComputeResponse<T>> {}
47
48#[async_trait]
49impl<T: Send> GenericClient<ComputeCommand<T>, ComputeResponse<T>> for Box<dyn ComputeClient<T>> {
50 async fn send(&mut self, cmd: ComputeCommand<T>) -> Result<(), anyhow::Error> {
51 (**self).send(cmd).await
52 }
53
54 async fn recv(&mut self) -> Result<Option<ComputeResponse<T>>, anyhow::Error> {
60 (**self).recv().await
62 }
63}
64
/// State for a compute client whose work is spread across several shards.
///
/// `split_command` fans a command out to the shards, and `absorb_response` merges the shards'
/// responses back into a single response stream.
#[derive(Debug)]
pub struct PartitionedComputeState<T> {
    /// Number of shards commands are split across; also the number of responses awaited
    /// before a merged peek or copy-to response is released.
    parts: usize,
    /// Byte limit enforced when merging peek responses (inline rows) and when stashing
    /// subscribe updates.
    ///
    /// Starts at `u64::MAX` and is replaced by `UpdateConfiguration` commands that set it.
    max_result_size: u64,
    /// Per-collection frontier tracking across shards.
    ///
    /// An entry is removed once all of the collection's combined frontiers are empty.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers<T>>,
    /// Peek responses merged so far, with the set of shards that have already responded.
    peek_responses: BTreeMap<Uuid, (PeekResponse, BTreeSet<usize>)>,
    /// Copy-to responses merged so far, with the set of shards that have already responded.
    copy_to_responses: BTreeMap<GlobalId, (CopyToResponse, BTreeSet<usize>)>,
    /// Subscribes whose combined frontier has not yet become empty.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe<T>>,
}
155
156impl<T> Partitionable<ComputeCommand<T>, ComputeResponse<T>>
157 for (ComputeCommand<T>, ComputeResponse<T>)
158where
159 T: ComputeControllerTimestamp,
160{
161 type PartitionedState = PartitionedComputeState<T>;
162
163 fn new(parts: usize) -> PartitionedComputeState<T> {
164 PartitionedComputeState {
165 parts,
166 max_result_size: u64::MAX,
167 frontiers: BTreeMap::new(),
168 peek_responses: BTreeMap::new(),
169 pending_subscribes: BTreeMap::new(),
170 copy_to_responses: BTreeMap::new(),
171 }
172 }
173}
174
175impl<T> PartitionedComputeState<T>
176where
177 T: ComputeControllerTimestamp,
178{
179 pub fn observe_command(&mut self, command: &ComputeCommand<T>) {
181 match command {
182 ComputeCommand::UpdateConfiguration(config) => {
183 if let Some(max_result_size) = config.max_result_size {
184 self.max_result_size = max_result_size;
185 }
186 }
187 _ => {
188 }
191 }
192 }
193
194 fn absorb_frontiers(
196 &mut self,
197 shard_id: usize,
198 collection_id: GlobalId,
199 frontiers: FrontiersResponse<T>,
200 ) -> Option<ComputeResponse<T>> {
201 let tracked = self
202 .frontiers
203 .entry(collection_id)
204 .or_insert_with(|| TrackedFrontiers::new(self.parts));
205
206 let write_frontier = frontiers
207 .write_frontier
208 .and_then(|f| tracked.update_write_frontier(shard_id, &f));
209 let input_frontier = frontiers
210 .input_frontier
211 .and_then(|f| tracked.update_input_frontier(shard_id, &f));
212 let output_frontier = frontiers
213 .output_frontier
214 .and_then(|f| tracked.update_output_frontier(shard_id, &f));
215
216 let frontiers = FrontiersResponse {
217 write_frontier,
218 input_frontier,
219 output_frontier,
220 };
221 let result = frontiers
222 .has_updates()
223 .then_some(ComputeResponse::Frontiers(collection_id, frontiers));
224
225 if tracked.all_empty() {
226 self.frontiers.remove(&collection_id);
229 }
230
231 result
232 }
233
234 fn absorb_peek_response(
236 &mut self,
237 shard_id: usize,
238 uuid: Uuid,
239 response: PeekResponse,
240 otel_ctx: OpenTelemetryContext,
241 ) -> Option<ComputeResponse<T>> {
242 let (merged, ready_shards) = self.peek_responses.entry(uuid).or_insert((
243 PeekResponse::Rows(RowCollection::default()),
244 BTreeSet::new(),
245 ));
246
247 let first = ready_shards.insert(shard_id);
248 assert!(first, "duplicate peek response");
249
250 let resp1 = mem::replace(merged, PeekResponse::Canceled);
251 *merged = merge_peek_responses(resp1, response, self.max_result_size);
252
253 if ready_shards.len() == self.parts {
254 let (response, _) = self.peek_responses.remove(&uuid).unwrap();
255 Some(ComputeResponse::PeekResponse(uuid, response, otel_ctx))
256 } else {
257 None
258 }
259 }
260
261 fn absorb_copy_to_response(
263 &mut self,
264 shard_id: usize,
265 copyto_id: GlobalId,
266 response: CopyToResponse,
267 ) -> Option<ComputeResponse<T>> {
268 use CopyToResponse::*;
269
270 let (merged, ready_shards) = self
271 .copy_to_responses
272 .entry(copyto_id)
273 .or_insert((CopyToResponse::RowCount(0), BTreeSet::new()));
274
275 let first = ready_shards.insert(shard_id);
276 assert!(first, "duplicate copy-to response");
277
278 let resp1 = mem::replace(merged, Dropped);
279 *merged = match (resp1, response) {
280 (Dropped, _) | (_, Dropped) => Dropped,
281 (Error(e), _) | (_, Error(e)) => Error(e),
282 (RowCount(r1), RowCount(r2)) => RowCount(r1 + r2),
283 };
284
285 if ready_shards.len() == self.parts {
286 let (response, _) = self.copy_to_responses.remove(©to_id).unwrap();
287 Some(ComputeResponse::CopyToResponse(copyto_id, response))
288 } else {
289 None
290 }
291 }
292
293 fn absorb_subscribe_response(
295 &mut self,
296 subscribe_id: GlobalId,
297 response: SubscribeResponse<T>,
298 ) -> Option<ComputeResponse<T>> {
299 let tracked = self
300 .pending_subscribes
301 .entry(subscribe_id)
302 .or_insert_with(|| PendingSubscribe::new(self.parts));
303
304 let emit_response = match response {
305 SubscribeResponse::Batch(batch) => {
306 let frontiers = &mut tracked.frontiers;
307 let old_frontier = frontiers.frontier().to_owned();
308 frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
309 frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
310 let new_frontier = frontiers.frontier().to_owned();
311
312 tracked.stash(batch.updates, self.max_result_size);
313
314 if old_frontier != new_frontier && !tracked.dropped {
318 let updates = match &mut tracked.stashed_updates {
319 Ok(stashed_updates) => {
320 consolidate_updates(stashed_updates);
323
324 let mut ship = Vec::new();
325 let mut keep = Vec::new();
326 for (time, data, diff) in stashed_updates.drain(..) {
327 if new_frontier.less_equal(&time) {
328 keep.push((time, data, diff));
329 } else {
330 ship.push((time, data, diff));
331 }
332 }
333 tracked.stashed_updates = Ok(keep);
334 Ok(ship)
335 }
336 Err(text) => Err(text.clone()),
337 };
338 Some(ComputeResponse::SubscribeResponse(
339 subscribe_id,
340 SubscribeResponse::Batch(SubscribeBatch {
341 lower: old_frontier,
342 upper: new_frontier,
343 updates,
344 }),
345 ))
346 } else {
347 None
348 }
349 }
350 SubscribeResponse::DroppedAt(frontier) => {
351 tracked
352 .frontiers
353 .update_iter(frontier.iter().map(|t| (t.clone(), -1)));
354
355 if tracked.dropped {
356 None
357 } else {
358 tracked.dropped = true;
359 Some(ComputeResponse::SubscribeResponse(
360 subscribe_id,
361 SubscribeResponse::DroppedAt(frontier),
362 ))
363 }
364 }
365 };
366
367 if tracked.frontiers.frontier().is_empty() {
368 self.pending_subscribes.remove(&subscribe_id);
371 }
372
373 emit_response
374 }
375}
376
377impl<T> PartitionedState<ComputeCommand<T>, ComputeResponse<T>> for PartitionedComputeState<T>
378where
379 T: ComputeControllerTimestamp,
380{
381 fn split_command(&mut self, command: ComputeCommand<T>) -> Vec<Option<ComputeCommand<T>>> {
382 self.observe_command(&command);
383
384 match command {
388 command @ ComputeCommand::Hello { .. }
389 | command @ ComputeCommand::UpdateConfiguration(_) => {
390 vec![Some(command); self.parts]
391 }
392 command => {
393 let mut r = vec![None; self.parts];
394 r[0] = Some(command);
395 r
396 }
397 }
398 }
399
400 fn absorb_response(
401 &mut self,
402 shard_id: usize,
403 message: ComputeResponse<T>,
404 ) -> Option<Result<ComputeResponse<T>, anyhow::Error>> {
405 let response = match message {
406 ComputeResponse::Frontiers(id, frontiers) => {
407 self.absorb_frontiers(shard_id, id, frontiers)
408 }
409 ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
410 self.absorb_peek_response(shard_id, uuid, response, otel_ctx)
411 }
412 ComputeResponse::SubscribeResponse(id, response) => {
413 self.absorb_subscribe_response(id, response)
414 }
415 ComputeResponse::CopyToResponse(id, response) => {
416 self.absorb_copy_to_response(shard_id, id, response)
417 }
418 response @ ComputeResponse::Status(_) => {
419 Some(response)
421 }
422 };
423
424 response.map(Ok)
425 }
426}
427
/// Tracks the write, input, and output frontiers of one collection across all shards.
///
/// Each entry pairs a `MutableAntichain` with the per-shard frontiers folded into it. The
/// `MutableAntichain` holds one count per element of each shard's current frontier, so its
/// frontier is the combined frontier of all shards.
#[derive(Debug)]
struct TrackedFrontiers<T> {
    /// Combined and per-shard write frontiers.
    write_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// Combined and per-shard input frontiers.
    input_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// Combined and per-shard output frontiers.
    output_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
}
441
442impl<T> TrackedFrontiers<T>
443where
444 T: timely::progress::Timestamp + Lattice,
445{
446 fn new(parts: usize) -> Self {
448 #[allow(clippy::as_conversions)]
450 let parts_diff = parts as i64;
451
452 let mut frontier = MutableAntichain::new();
453 frontier.update_iter([(T::minimum(), parts_diff)]);
454 let part_frontiers = vec![Antichain::from_elem(T::minimum()); parts];
455 let frontier_entry = (frontier, part_frontiers);
456
457 Self {
458 write_frontier: frontier_entry.clone(),
459 input_frontier: frontier_entry.clone(),
460 output_frontier: frontier_entry,
461 }
462 }
463
464 fn all_empty(&self) -> bool {
466 self.write_frontier.0.frontier().is_empty()
467 && self.input_frontier.0.frontier().is_empty()
468 && self.output_frontier.0.frontier().is_empty()
469 }
470
471 fn update_write_frontier(
475 &mut self,
476 shard_id: usize,
477 new_shard_frontier: &Antichain<T>,
478 ) -> Option<Antichain<T>> {
479 Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
480 }
481
482 fn update_input_frontier(
486 &mut self,
487 shard_id: usize,
488 new_shard_frontier: &Antichain<T>,
489 ) -> Option<Antichain<T>> {
490 Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
491 }
492
493 fn update_output_frontier(
497 &mut self,
498 shard_id: usize,
499 new_shard_frontier: &Antichain<T>,
500 ) -> Option<Antichain<T>> {
501 Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
502 }
503
504 fn update_frontier(
506 entry: &mut (MutableAntichain<T>, Vec<Antichain<T>>),
507 shard_id: usize,
508 new_shard_frontier: &Antichain<T>,
509 ) -> Option<Antichain<T>> {
510 let (frontier, shard_frontiers) = entry;
511
512 let old_frontier = frontier.frontier().to_owned();
513 let shard_frontier = &mut shard_frontiers[shard_id];
514 frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
515 shard_frontier.join_assign(new_shard_frontier);
516 frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));
517
518 let new_frontier = frontier.frontier();
519
520 if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
521 Some(new_frontier.to_owned())
522 } else {
523 None
524 }
525 }
526}
527
/// Merge state for a subscribe whose combined frontier across shards is not yet empty.
#[derive(Debug)]
struct PendingSubscribe<T> {
    /// Combined frontier of the subscribe across all shards.
    frontiers: MutableAntichain<T>,
    /// Updates held back until the combined frontier passes their times, or the error that
    /// replaced them.
    stashed_updates: Result<Vec<(T, Row, Diff)>, String>,
    /// Row byte size counted against `max_result_size` when updates are stashed.
    stashed_result_size: usize,
    /// Whether a `DroppedAt` response has already been emitted for this subscribe.
    ///
    /// Used so that such a response is emitted at most once.
    dropped: bool,
}
541
542impl<T: ComputeControllerTimestamp> PendingSubscribe<T> {
543 fn new(parts: usize) -> Self {
544 let mut frontiers = MutableAntichain::new();
545 #[allow(clippy::as_conversions)]
547 frontiers.update_iter([(T::minimum(), parts as i64)]);
548
549 Self {
550 frontiers,
551 stashed_updates: Ok(Vec::new()),
552 stashed_result_size: 0,
553 dropped: false,
554 }
555 }
556
557 fn stash(&mut self, new_updates: Result<Vec<(T, Row, Diff)>, String>, max_result_size: u64) {
562 match (&mut self.stashed_updates, new_updates) {
563 (Err(_), _) => {
564 }
567 (_, Err(text)) => {
568 self.stashed_updates = Err(text);
569 }
570 (Ok(stashed), Ok(new)) => {
571 let new_size: usize = new.iter().map(|(_, row, _)| row.byte_len()).sum();
572 self.stashed_result_size += new_size;
573
574 if self.stashed_result_size > max_result_size.cast_into() {
575 self.stashed_updates = Err(format!(
576 "total result exceeds max size of {}",
577 ByteSize::b(max_result_size)
578 ));
579 } else {
580 stashed.extend(new);
581 }
582 }
583 }
584 }
585}
586
587fn merge_peek_responses(
589 resp1: PeekResponse,
590 resp2: PeekResponse,
591 max_result_size: u64,
592) -> PeekResponse {
593 use PeekResponse::*;
594
595 let (resp1, resp2) = match (resp1, resp2) {
597 (Canceled, _) | (_, Canceled) => return Canceled,
598 (Error(e), _) | (_, Error(e)) => return Error(e),
599 resps => resps,
600 };
601
602 let total_byte_len = resp1.inline_byte_len() + resp2.inline_byte_len();
603 if total_byte_len > max_result_size.cast_into() {
604 let err = format!(
607 "total result exceeds max size of {}",
608 ByteSize::b(max_result_size)
609 );
610 return Error(err);
611 }
612
613 match (resp1, resp2) {
614 (Rows(mut rows1), Rows(rows2)) => {
615 rows1.merge(&rows2);
616 Rows(rows1)
617 }
618 (Rows(rows), Stashed(mut stashed)) | (Stashed(mut stashed), Rows(rows)) => {
619 stashed.inline_rows.merge(&rows);
620 Stashed(stashed)
621 }
622 (Stashed(stashed1), Stashed(stashed2)) => {
623 let StashedPeekResponse {
626 num_rows_batches: num_rows_batches1,
627 encoded_size_bytes: encoded_size_bytes1,
628 relation_desc: relation_desc1,
629 shard_id: shard_id1,
630 batches: mut batches1,
631 inline_rows: mut inline_rows1,
632 } = *stashed1;
633 let StashedPeekResponse {
634 num_rows_batches: num_rows_batches2,
635 encoded_size_bytes: encoded_size_bytes2,
636 relation_desc: relation_desc2,
637 shard_id: shard_id2,
638 batches: mut batches2,
639 inline_rows: inline_rows2,
640 } = *stashed2;
641
642 if shard_id1 != shard_id2 {
643 soft_panic_or_log!(
644 "shard IDs of stashed responses do not match: \
645 {shard_id1} != {shard_id2}"
646 );
647 return Error("internal error".into());
648 }
649 if relation_desc1 != relation_desc2 {
650 soft_panic_or_log!(
651 "relation descs of stashed responses do not match: \
652 {relation_desc1:?} != {relation_desc2:?}"
653 );
654 return Error("internal error".into());
655 }
656
657 batches1.append(&mut batches2);
658 inline_rows1.merge(&inline_rows2);
659
660 Stashed(Box::new(StashedPeekResponse {
661 num_rows_batches: num_rows_batches1 + num_rows_batches2,
662 encoded_size_bytes: encoded_size_bytes1 + encoded_size_bytes2,
663 relation_desc: relation_desc1,
664 shard_id: shard_id1,
665 batches: batches1,
666 inline_rows: inline_rows1,
667 }))
668 }
669 _ => unreachable!("handled above"),
670 }
671}