use std::collections::{BTreeMap, BTreeSet};
use std::mem;

use async_trait::async_trait;
use bytesize::ByteSize;
use differential_dataflow::consolidation::consolidate_updates;
use differential_dataflow::lattice::Lattice;
use mz_expr::row::RowCollection;
use mz_ore::cast::CastInto;
use mz_ore::soft_panic_or_log;
use mz_ore::tracing::OpenTelemetryContext;
use mz_repr::{Diff, GlobalId, Row};
use mz_service::client::{GenericClient, Partitionable, PartitionedState};
use timely::PartialOrder;
use timely::progress::frontier::{Antichain, MutableAntichain};
use uuid::Uuid;

use crate::controller::ComputeControllerTimestamp;
use crate::protocol::command::ComputeCommand;
use crate::protocol::response::{
    ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, StashedPeekResponse,
    SubscribeBatch, SubscribeResponse,
};

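/// A client to a compute server.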
pub trait ComputeClient<T = mz_repr::Timestamp>:
    GenericClient<ComputeCommand<T>, ComputeResponse<T>>
{
}

impl<C, T> ComputeClient<T> for C where C: GenericClient<ComputeCommand<T>, ComputeResponse<T>> {}

#[async_trait]
impl<T: Send> GenericClient<ComputeCommand<T>, ComputeResponse<T>> for Box<dyn ComputeClient<T>> {
    async fn send(&mut self, cmd: ComputeCommand<T>) -> Result<(), anyhow::Error> {
        (**self).send(cmd).await
    }

    async fn recv(&mut self) -> Result<Option<ComputeResponse<T>>, anyhow::Error> {
        (**self).recv().await
    }
}

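/// Maintains the state of a client to a partitioned compute instance.
///
/// This state is used to fan out commands to the individual partitions (shards)
/// and to reassemble their responses into single, client-facing responses.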
#[derive(Debug)]
pub struct PartitionedComputeState<T> {
    /// The number of partitions the compute instance is split across.
    parts: usize,
    /// The maximum result size accepted for peeks and subscribes, as most
    /// recently communicated through `ComputeCommand::UpdateConfiguration`.
    max_result_size: u64,
    /// Frontier tracking for installed collections, used to merge per-shard
    /// `Frontiers` responses into collection-wide frontier updates.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers<T>>,
    /// Pending peek responses, keyed by peek UUID, along with the set of
    /// shards that have already responded.
    peek_responses: BTreeMap<Uuid, (PeekResponse, BTreeSet<usize>)>,
    /// Pending copy-to responses, keyed by the ID of the exporting collection,
    /// along with the set of shards that have already responded.
    copy_to_responses: BTreeMap<GlobalId, (CopyToResponse, BTreeSet<usize>)>,
    /// State tracking for in-progress subscribes, keyed by subscribe ID.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe<T>>,
}

impl<T> Partitionable<ComputeCommand<T>, ComputeResponse<T>>
    for (ComputeCommand<T>, ComputeResponse<T>)
where
    T: ComputeControllerTimestamp,
{
    type PartitionedState = PartitionedComputeState<T>;

    fn new(parts: usize) -> PartitionedComputeState<T> {
        PartitionedComputeState {
            parts,
            max_result_size: u64::MAX,
            frontiers: BTreeMap::new(),
            peek_responses: BTreeMap::new(),
            pending_subscribes: BTreeMap::new(),
            copy_to_responses: BTreeMap::new(),
        }
    }
}

impl<T> PartitionedComputeState<T>
where
    T: ComputeControllerTimestamp,
{
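    /// Observes a command sent to the partitions and updates the client state
    /// accordingly.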
    pub fn observe_command(&mut self, command: &ComputeCommand<T>) {
        match command {
            ComputeCommand::UpdateConfiguration(config) => {
                if let Some(max_result_size) = config.max_result_size {
                    self.max_result_size = max_result_size;
                }
            }
            // No other commands affect the partitioned client state.
            _ => {}
        }
    }

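    /// Absorbs a `Frontiers` response from a single shard, returning a merged
    /// response for the whole collection if the shard update advanced any of
    /// the tracked frontiers.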
    fn absorb_frontiers(
        &mut self,
        shard_id: usize,
        collection_id: GlobalId,
        frontiers: FrontiersResponse<T>,
    ) -> Option<ComputeResponse<T>> {
        let tracked = self
            .frontiers
            .entry(collection_id)
            .or_insert_with(|| TrackedFrontiers::new(self.parts));

        let write_frontier = frontiers
            .write_frontier
            .and_then(|f| tracked.update_write_frontier(shard_id, &f));
        let input_frontier = frontiers
            .input_frontier
            .and_then(|f| tracked.update_input_frontier(shard_id, &f));
        let output_frontier = frontiers
            .output_frontier
            .and_then(|f| tracked.update_output_frontier(shard_id, &f));

        let frontiers = FrontiersResponse {
            write_frontier,
            input_frontier,
            output_frontier,
        };
        let result = frontiers
            .has_updates()
            .then_some(ComputeResponse::Frontiers(collection_id, frontiers));

        if tracked.all_empty() {
            // All shards have reported the empty frontier, so we can stop
            // tracking state for this collection.
            self.frontiers.remove(&collection_id);
        }

        result
    }

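    /// Absorbs a `PeekResponse` from a single shard, returning the merged
    /// response once all shards have reported back for the given peek.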
    fn absorb_peek_response(
        &mut self,
        shard_id: usize,
        uuid: Uuid,
        response: PeekResponse,
        otel_ctx: OpenTelemetryContext,
    ) -> Option<ComputeResponse<T>> {
        let (merged, ready_shards) = self.peek_responses.entry(uuid).or_insert((
            PeekResponse::Rows(RowCollection::default()),
            BTreeSet::new(),
        ));

        let first = ready_shards.insert(shard_id);
        assert!(first, "duplicate peek response");

        // Temporarily swap in a placeholder so we can merge by value.
        let resp1 = mem::replace(merged, PeekResponse::Canceled);
        *merged = merge_peek_responses(resp1, response, self.max_result_size);

        if ready_shards.len() == self.parts {
            let (response, _) = self.peek_responses.remove(&uuid).unwrap();
            Some(ComputeResponse::PeekResponse(uuid, response, otel_ctx))
        } else {
            None
        }
    }

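    /// Absorbs a `CopyToResponse` from a single shard, returning the merged
    /// response once all shards have reported back for the given copy-to.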
    fn absorb_copy_to_response(
        &mut self,
        shard_id: usize,
        copyto_id: GlobalId,
        response: CopyToResponse,
    ) -> Option<ComputeResponse<T>> {
        use CopyToResponse::*;

        let (merged, ready_shards) = self
            .copy_to_responses
            .entry(copyto_id)
            .or_insert((CopyToResponse::RowCount(0), BTreeSet::new()));

        let first = ready_shards.insert(shard_id);
        assert!(first, "duplicate copy-to response");

        // Dropped and error responses take precedence; otherwise sum the row
        // counts reported by the individual shards.
        let resp1 = mem::replace(merged, Dropped);
        *merged = match (resp1, response) {
            (Dropped, _) | (_, Dropped) => Dropped,
            (Error(e), _) | (_, Error(e)) => Error(e),
            (RowCount(r1), RowCount(r2)) => RowCount(r1 + r2),
        };

        if ready_shards.len() == self.parts {
            let (response, _) = self.copy_to_responses.remove(&copyto_id).unwrap();
            Some(ComputeResponse::CopyToResponse(copyto_id, response))
        } else {
            None
        }
    }

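    /// Absorbs a `SubscribeResponse` from a single shard, stashing updates
    /// until the subscribe's overall frontier advances and a consolidated
    /// batch can be emitted.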
    fn absorb_subscribe_response(
        &mut self,
        subscribe_id: GlobalId,
        response: SubscribeResponse<T>,
    ) -> Option<ComputeResponse<T>> {
        let tracked = self
            .pending_subscribes
            .entry(subscribe_id)
            .or_insert_with(|| PendingSubscribe::new(self.parts));

        let emit_response = match response {
            SubscribeResponse::Batch(batch) => {
                let frontiers = &mut tracked.frontiers;
                let old_frontier = frontiers.frontier().to_owned();
                frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
                frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
                let new_frontier = frontiers.frontier().to_owned();

                tracked.stash(batch.updates, self.max_result_size);

                // Emit a batch if the overall frontier has advanced and the
                // subscribe has not already been dropped.
                if old_frontier != new_frontier && !tracked.dropped {
                    let updates = match &mut tracked.stashed_updates {
                        Ok(stashed_updates) => {
                            consolidate_updates(stashed_updates);

                            // Ship updates at times before the new frontier;
                            // keep the rest stashed for a later batch.
                            let mut ship = Vec::new();
                            let mut keep = Vec::new();
                            for (time, data, diff) in stashed_updates.drain(..) {
                                if new_frontier.less_equal(&time) {
                                    keep.push((time, data, diff));
                                } else {
                                    ship.push((time, data, diff));
                                }
                            }
                            tracked.stashed_updates = Ok(keep);
                            Ok(ship)
                        }
                        Err(text) => Err(text.clone()),
                    };
                    Some(ComputeResponse::SubscribeResponse(
                        subscribe_id,
                        SubscribeResponse::Batch(SubscribeBatch {
                            lower: old_frontier,
                            upper: new_frontier,
                            updates,
                        }),
                    ))
                } else {
                    None
                }
            }
            SubscribeResponse::DroppedAt(frontier) => {
                tracked
                    .frontiers
                    .update_iter(frontier.iter().map(|t| (t.clone(), -1)));

                // Only the first `DroppedAt` response is forwarded.
                if tracked.dropped {
                    None
                } else {
                    tracked.dropped = true;
                    Some(ComputeResponse::SubscribeResponse(
                        subscribe_id,
                        SubscribeResponse::DroppedAt(frontier),
                    ))
                }
            }
        };

        if tracked.frontiers.frontier().is_empty() {
            // All shards have reported the empty frontier, so the subscribe
            // is finished and its state can be dropped.
            self.pending_subscribes.remove(&subscribe_id);
        }

        emit_response
    }
}

impl<T> PartitionedState<ComputeCommand<T>, ComputeResponse<T>> for PartitionedComputeState<T>
where
    T: ComputeControllerTimestamp,
{
    fn split_command(&mut self, command: ComputeCommand<T>) -> Vec<Option<ComputeCommand<T>>> {
        self.observe_command(&command);

        match command {
            // Broadcast these commands to all shards.
            command @ ComputeCommand::Hello { .. }
            | command @ ComputeCommand::UpdateConfiguration(_) => {
                vec![Some(command); self.parts]
            }
            // Send all other commands to the first shard only.
            command => {
                let mut r = vec![None; self.parts];
                r[0] = Some(command);
                r
            }
        }
    }

    fn absorb_response(
        &mut self,
        shard_id: usize,
        message: ComputeResponse<T>,
    ) -> Option<Result<ComputeResponse<T>, anyhow::Error>> {
        let response = match message {
            ComputeResponse::Frontiers(id, frontiers) => {
                self.absorb_frontiers(shard_id, id, frontiers)
            }
            ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
                self.absorb_peek_response(shard_id, uuid, response, otel_ctx)
            }
            ComputeResponse::SubscribeResponse(id, response) => {
                self.absorb_subscribe_response(id, response)
            }
            ComputeResponse::CopyToResponse(id, response) => {
                self.absorb_copy_to_response(shard_id, id, response)
            }
            // Status responses are passed through unchanged.
            response @ ComputeResponse::Status(_) => Some(response),
        };

        response.map(Ok)
    }
}

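/// Tracks the frontiers of a collection, both overall and per shard.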
#[derive(Debug)]
struct TrackedFrontiers<T> {
    /// The overall write frontier and the per-shard write frontiers.
    write_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The overall input frontier and the per-shard input frontiers.
    input_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The overall output frontier and the per-shard output frontiers.
    output_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
}

impl<T> TrackedFrontiers<T>
where
    T: timely::progress::Timestamp + Lattice,
{
    /// Returns a new `TrackedFrontiers`, initialized to the minimum frontier,
    /// for the given number of shards.
    fn new(parts: usize) -> Self {
        // Each overall frontier starts with the minimum time at a
        // multiplicity of `parts`, one per shard.
        #[allow(clippy::as_conversions)]
        let parts_diff = parts as i64;

        let mut frontier = MutableAntichain::new();
        frontier.update_iter([(T::minimum(), parts_diff)]);
        let part_frontiers = vec![Antichain::from_elem(T::minimum()); parts];
        let frontier_entry = (frontier, part_frontiers);

        Self {
            write_frontier: frontier_entry.clone(),
            input_frontier: frontier_entry.clone(),
            output_frontier: frontier_entry,
        }
    }

    /// Reports whether all tracked frontiers are empty, i.e., all shards have
    /// reported completion of the collection.
    fn all_empty(&self) -> bool {
        self.write_frontier.0.frontier().is_empty()
            && self.input_frontier.0.frontier().is_empty()
            && self.output_frontier.0.frontier().is_empty()
    }

    /// Updates write frontier tracking with a new shard frontier, returning
    /// the new overall frontier if it advanced.
    fn update_write_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
    }

    /// Updates input frontier tracking with a new shard frontier, returning
    /// the new overall frontier if it advanced.
    fn update_input_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
    }

    /// Updates output frontier tracking with a new shard frontier, returning
    /// the new overall frontier if it advanced.
    fn update_output_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
    }

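    /// Updates the given `(overall, per-shard)` tracking entry with a new
    /// frontier for one shard, returning the new overall frontier if it
    /// advanced.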
    fn update_frontier(
        entry: &mut (MutableAntichain<T>, Vec<Antichain<T>>),
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        let (frontier, shard_frontiers) = entry;

        let old_frontier = frontier.frontier().to_owned();
        let shard_frontier = &mut shard_frontiers[shard_id];
        // Retract the shard's old frontier, advance it by joining in the new
        // one, then re-insert it into the overall `MutableAntichain`.
        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
        shard_frontier.join_assign(new_shard_frontier);
        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));

        let new_frontier = frontier.frontier();

        if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
            Some(new_frontier.to_owned())
        } else {
            None
        }
    }
}

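/// State tracking for an in-progress subscribe, merged across all shards.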
#[derive(Debug)]
struct PendingSubscribe<T> {
    /// The subscribe's progress, tracked across all shards.
    frontiers: MutableAntichain<T>,
    /// Updates received but not yet emitted, or an error.
    stashed_updates: Result<Vec<(T, Row, Diff)>, String>,
    /// The total size, in bytes, of the stashed row data.
    stashed_result_size: usize,
    /// Whether a `DroppedAt` response has already been forwarded for this
    /// subscribe.
    dropped: bool,
}

impl<T: ComputeControllerTimestamp> PendingSubscribe<T> {
    fn new(parts: usize) -> Self {
        let mut frontiers = MutableAntichain::new();
        // Each shard starts at the minimum frontier.
        #[allow(clippy::as_conversions)]
        frontiers.update_iter([(T::minimum(), parts as i64)]);

        Self {
            frontiers,
            stashed_updates: Ok(Vec::new()),
            stashed_result_size: 0,
            dropped: false,
        }
    }

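    /// Stashes the given updates, or replaces the stash with an error if the
    /// accumulated result size exceeds `max_result_size`.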
    fn stash(&mut self, new_updates: Result<Vec<(T, Row, Diff)>, String>, max_result_size: u64) {
        match (&mut self.stashed_updates, new_updates) {
            // Once in an errored state, stay there.
            (Err(_), _) => {}
            (_, Err(text)) => {
                self.stashed_updates = Err(text);
            }
            (Ok(stashed), Ok(new)) => {
                let new_size: usize = new.iter().map(|(_, row, _)| row.byte_len()).sum();
                self.stashed_result_size += new_size;

                if self.stashed_result_size > max_result_size.cast_into() {
                    self.stashed_updates = Err(format!(
                        "total result exceeds max size of {}",
                        ByteSize::b(max_result_size)
                    ));
                } else {
                    stashed.extend(new);
                }
            }
        }
    }
}

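/// Merges two peek responses, enforcing `max_result_size` on the combined
/// inline result.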
fn merge_peek_responses(
    resp1: PeekResponse,
    resp2: PeekResponse,
    max_result_size: u64,
) -> PeekResponse {
    use PeekResponse::*;

    // Cancellations and errors dominate all other responses.
    let (resp1, resp2) = match (resp1, resp2) {
        (Canceled, _) | (_, Canceled) => return Canceled,
        (Error(e), _) | (_, Error(e)) => return Error(e),
        resps => resps,
    };

    let total_byte_len = resp1.inline_byte_len() + resp2.inline_byte_len();
    if total_byte_len > max_result_size.cast_into() {
        let err = format!(
            "total result exceeds max size of {}",
            ByteSize::b(max_result_size)
        );
        return Error(err);
    }

    match (resp1, resp2) {
        (Rows(mut rows1), Rows(rows2)) => {
            rows1.merge(&rows2);
            Rows(rows1)
        }
        (Rows(rows), Stashed(mut stashed)) | (Stashed(mut stashed), Rows(rows)) => {
            stashed.inline_rows.merge(&rows);
            Stashed(stashed)
        }
        (Stashed(stashed1), Stashed(stashed2)) => {
            let StashedPeekResponse {
                num_rows_batches: num_rows_batches1,
                encoded_size_bytes: encoded_size_bytes1,
                relation_desc: relation_desc1,
                shard_id: shard_id1,
                batches: mut batches1,
                inline_rows: mut inline_rows1,
            } = *stashed1;
            let StashedPeekResponse {
                num_rows_batches: num_rows_batches2,
                encoded_size_bytes: encoded_size_bytes2,
                relation_desc: relation_desc2,
                shard_id: shard_id2,
                batches: mut batches2,
                inline_rows: inline_rows2,
            } = *stashed2;

            if shard_id1 != shard_id2 {
                soft_panic_or_log!(
                    "shard IDs of stashed responses do not match: \
                     {shard_id1} != {shard_id2}"
                );
                return Error("internal error".into());
            }
            if relation_desc1 != relation_desc2 {
                soft_panic_or_log!(
                    "relation descs of stashed responses do not match: \
                     {relation_desc1:?} != {relation_desc2:?}"
                );
                return Error("internal error".into());
            }

            batches1.append(&mut batches2);
            inline_rows1.merge(&inline_rows2);

            Stashed(Box::new(StashedPeekResponse {
                num_rows_batches: num_rows_batches1 + num_rows_batches2,
                encoded_size_bytes: encoded_size_bytes1 + encoded_size_bytes2,
                relation_desc: relation_desc1,
                shard_id: shard_id1,
                batches: batches1,
                inline_rows: inline_rows1,
            }))
        }
        _ => unreachable!("handled above"),
    }
}