1use std::collections::{BTreeMap, BTreeSet};
13use std::mem;
14
15use async_trait::async_trait;
16use bytesize::ByteSize;
17use differential_dataflow::lattice::Lattice;
18use mz_expr::row::RowCollection;
19use mz_ore::cast::CastInto;
20use mz_ore::soft_panic_or_log;
21use mz_ore::tracing::OpenTelemetryContext;
22use mz_repr::{GlobalId, UpdateCollection};
23use mz_service::client::{GenericClient, Partitionable, PartitionedState};
24use timely::PartialOrder;
25use timely::progress::frontier::{Antichain, MutableAntichain};
26use uuid::Uuid;
27
28use crate::controller::ComputeControllerTimestamp;
29use crate::protocol::command::ComputeCommand;
30use crate::protocol::response::{
31 ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, StashedPeekResponse,
32 SubscribeBatch, SubscribeResponse,
33};
34
/// A client to a compute server.
///
/// Alias trait for any client that speaks the compute protocol, i.e. sends
/// [`ComputeCommand`]s and receives [`ComputeResponse`]s. The timestamp type
/// defaults to [`mz_repr::Timestamp`].
pub trait ComputeClient<T = mz_repr::Timestamp>:
    GenericClient<ComputeCommand<T>, ComputeResponse<T>>
{
}
40
/// Blanket implementation: every compute-flavored `GenericClient` is a `ComputeClient`.
impl<C, T> ComputeClient<T> for C where C: GenericClient<ComputeCommand<T>, ComputeResponse<T>> {}
42
// Forwarding implementation so a boxed trait object can itself be used
// wherever a `GenericClient` is expected.
#[async_trait]
impl<T: Send> GenericClient<ComputeCommand<T>, ComputeResponse<T>> for Box<dyn ComputeClient<T>> {
    async fn send(&mut self, cmd: ComputeCommand<T>) -> Result<(), anyhow::Error> {
        // Plain delegation to the boxed client.
        (**self).send(cmd).await
    }

    // NOTE(review): cancel safety of `recv` is inherited from the underlying
    // client implementation — confirm against its documentation.
    async fn recv(&mut self) -> Result<Option<ComputeResponse<T>>, anyhow::Error> {
        (**self).recv().await
    }
}
59
/// State for merging the command/response streams of the individual shards
/// (partitions) of a multi-process compute replica into a single logical
/// stream.
#[derive(Debug)]
pub struct PartitionedComputeState<T> {
    // Number of shards this state fans out to / merges from.
    parts: usize,
    // Upper bound on combined result sizes, as last observed through an
    // `UpdateConfiguration` command; starts unlimited (`u64::MAX`).
    max_result_size: u64,
    // Per-collection frontier tracking across all shards; entries are removed
    // once all tracked frontiers of a collection are empty.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers<T>>,
    // In-progress peeks: the merged-so-far response plus the set of shards
    // that have already reported, keyed by peek UUID.
    peek_responses: BTreeMap<Uuid, (PeekResponse, BTreeSet<usize>)>,
    // In-progress `COPY TO`s: merged-so-far response plus reporting shards.
    copy_to_responses: BTreeMap<GlobalId, (CopyToResponse, BTreeSet<usize>)>,
    // In-progress subscribes: stashed updates and frontier bookkeeping.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe<T>>,
}
150
151impl<T> Partitionable<ComputeCommand<T>, ComputeResponse<T>>
152 for (ComputeCommand<T>, ComputeResponse<T>)
153where
154 T: ComputeControllerTimestamp,
155{
156 type PartitionedState = PartitionedComputeState<T>;
157
158 fn new(parts: usize) -> PartitionedComputeState<T> {
159 PartitionedComputeState {
160 parts,
161 max_result_size: u64::MAX,
162 frontiers: BTreeMap::new(),
163 peek_responses: BTreeMap::new(),
164 pending_subscribes: BTreeMap::new(),
165 copy_to_responses: BTreeMap::new(),
166 }
167 }
168}
169
170impl<T> PartitionedComputeState<T>
171where
172 T: ComputeControllerTimestamp,
173{
174 pub fn observe_command(&mut self, command: &ComputeCommand<T>) {
176 match command {
177 ComputeCommand::UpdateConfiguration(config) => {
178 if let Some(max_result_size) = config.max_result_size {
179 self.max_result_size = max_result_size;
180 }
181 }
182 _ => {
183 }
186 }
187 }
188
189 fn absorb_frontiers(
191 &mut self,
192 shard_id: usize,
193 collection_id: GlobalId,
194 frontiers: FrontiersResponse<T>,
195 ) -> Option<ComputeResponse<T>> {
196 let tracked = self
197 .frontiers
198 .entry(collection_id)
199 .or_insert_with(|| TrackedFrontiers::new(self.parts));
200
201 let write_frontier = frontiers
202 .write_frontier
203 .and_then(|f| tracked.update_write_frontier(shard_id, &f));
204 let input_frontier = frontiers
205 .input_frontier
206 .and_then(|f| tracked.update_input_frontier(shard_id, &f));
207 let output_frontier = frontiers
208 .output_frontier
209 .and_then(|f| tracked.update_output_frontier(shard_id, &f));
210
211 let frontiers = FrontiersResponse {
212 write_frontier,
213 input_frontier,
214 output_frontier,
215 };
216 let result = frontiers
217 .has_updates()
218 .then_some(ComputeResponse::Frontiers(collection_id, frontiers));
219
220 if tracked.all_empty() {
221 self.frontiers.remove(&collection_id);
224 }
225
226 result
227 }
228
229 fn absorb_peek_response(
231 &mut self,
232 shard_id: usize,
233 uuid: Uuid,
234 response: PeekResponse,
235 otel_ctx: OpenTelemetryContext,
236 ) -> Option<ComputeResponse<T>> {
237 let (merged, ready_shards) = self.peek_responses.entry(uuid).or_insert((
238 PeekResponse::Rows(vec![RowCollection::default()]),
239 BTreeSet::new(),
240 ));
241
242 let first = ready_shards.insert(shard_id);
243 assert!(first, "duplicate peek response");
244
245 let resp1 = mem::replace(merged, PeekResponse::Canceled);
246 *merged = merge_peek_responses(resp1, response, self.max_result_size);
247
248 if ready_shards.len() == self.parts {
249 let (response, _) = self.peek_responses.remove(&uuid).unwrap();
250 Some(ComputeResponse::PeekResponse(uuid, response, otel_ctx))
251 } else {
252 None
253 }
254 }
255
256 fn absorb_copy_to_response(
258 &mut self,
259 shard_id: usize,
260 copyto_id: GlobalId,
261 response: CopyToResponse,
262 ) -> Option<ComputeResponse<T>> {
263 use CopyToResponse::*;
264
265 let (merged, ready_shards) = self
266 .copy_to_responses
267 .entry(copyto_id)
268 .or_insert((CopyToResponse::RowCount(0), BTreeSet::new()));
269
270 let first = ready_shards.insert(shard_id);
271 assert!(first, "duplicate copy-to response");
272
273 let resp1 = mem::replace(merged, Dropped);
274 *merged = match (resp1, response) {
275 (Dropped, _) | (_, Dropped) => Dropped,
276 (Error(e), _) | (_, Error(e)) => Error(e),
277 (RowCount(r1), RowCount(r2)) => RowCount(r1 + r2),
278 };
279
280 if ready_shards.len() == self.parts {
281 let (response, _) = self.copy_to_responses.remove(©to_id).unwrap();
282 Some(ComputeResponse::CopyToResponse(copyto_id, response))
283 } else {
284 None
285 }
286 }
287
288 fn absorb_subscribe_response(
290 &mut self,
291 subscribe_id: GlobalId,
292 response: SubscribeResponse<T>,
293 ) -> Option<ComputeResponse<T>> {
294 let tracked = self
295 .pending_subscribes
296 .entry(subscribe_id)
297 .or_insert_with(|| PendingSubscribe::new(self.parts));
298
299 let emit_response = match response {
300 SubscribeResponse::Batch(batch) => {
301 let frontiers = &mut tracked.frontiers;
302 let old_frontier = frontiers.frontier().to_owned();
303 frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
304 frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
305 let new_frontier = frontiers.frontier().to_owned();
306
307 tracked.stash(batch.updates, self.max_result_size);
308
309 if old_frontier != new_frontier && !tracked.dropped {
313 let updates = match &mut tracked.stashed_updates {
314 Ok(stashed_updates) => {
315 let mut ship = vec![];
317 let mut keep = vec![];
318 for collection in stashed_updates.drain(..) {
319 let partition_point = collection
320 .times()
321 .partition_point(|t| !new_frontier.less_equal(t));
322 let (ship_coll, keep_coll) = collection.split_at(partition_point);
323 if ship_coll.len() > 0 {
324 ship.push(ship_coll);
325 }
326 if keep_coll.len() > 0 {
327 keep.push(keep_coll);
328 }
329 }
330 tracked.stashed_updates = Ok(keep);
331 Ok(ship)
332 }
333 Err(text) => Err(text.clone()),
334 };
335 Some(ComputeResponse::SubscribeResponse(
336 subscribe_id,
337 SubscribeResponse::Batch(SubscribeBatch {
338 lower: old_frontier,
339 upper: new_frontier,
340 updates,
341 }),
342 ))
343 } else {
344 None
345 }
346 }
347 SubscribeResponse::DroppedAt(frontier) => {
348 tracked
349 .frontiers
350 .update_iter(frontier.iter().map(|t| (t.clone(), -1)));
351
352 if tracked.dropped {
353 None
354 } else {
355 tracked.dropped = true;
356 Some(ComputeResponse::SubscribeResponse(
357 subscribe_id,
358 SubscribeResponse::DroppedAt(frontier),
359 ))
360 }
361 }
362 };
363
364 if tracked.frontiers.frontier().is_empty() {
365 self.pending_subscribes.remove(&subscribe_id);
368 }
369
370 emit_response
371 }
372}
373
374impl<T> PartitionedState<ComputeCommand<T>, ComputeResponse<T>> for PartitionedComputeState<T>
375where
376 T: ComputeControllerTimestamp,
377{
378 fn split_command(&mut self, command: ComputeCommand<T>) -> Vec<Option<ComputeCommand<T>>> {
379 self.observe_command(&command);
380
381 match command {
385 command @ ComputeCommand::Hello { .. }
386 | command @ ComputeCommand::UpdateConfiguration(_) => {
387 vec![Some(command); self.parts]
388 }
389 command => {
390 let mut r = vec![None; self.parts];
391 r[0] = Some(command);
392 r
393 }
394 }
395 }
396
397 fn absorb_response(
398 &mut self,
399 shard_id: usize,
400 message: ComputeResponse<T>,
401 ) -> Option<Result<ComputeResponse<T>, anyhow::Error>> {
402 let response = match message {
403 ComputeResponse::Frontiers(id, frontiers) => {
404 self.absorb_frontiers(shard_id, id, frontiers)
405 }
406 ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
407 self.absorb_peek_response(shard_id, uuid, response, otel_ctx)
408 }
409 ComputeResponse::SubscribeResponse(id, response) => {
410 self.absorb_subscribe_response(id, response)
411 }
412 ComputeResponse::CopyToResponse(id, response) => {
413 self.absorb_copy_to_response(shard_id, id, response)
414 }
415 response @ ComputeResponse::Status(_) => {
416 Some(response)
418 }
419 };
420
421 response.map(Ok)
422 }
423}
424
/// Frontier tracking for a single collection across all shards.
///
/// Each field pairs the overall frontier (a `MutableAntichain` combining the
/// contributions of all shards) with the last frontier reported by each
/// individual shard.
#[derive(Debug)]
struct TrackedFrontiers<T> {
    // The tracked write frontier: (combined, per-shard).
    write_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    // The tracked input frontier: (combined, per-shard).
    input_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    // The tracked output frontier: (combined, per-shard).
    output_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
}
438
439impl<T> TrackedFrontiers<T>
440where
441 T: timely::progress::Timestamp + Lattice,
442{
443 fn new(parts: usize) -> Self {
445 #[allow(clippy::as_conversions)]
447 let parts_diff = parts as i64;
448
449 let mut frontier = MutableAntichain::new();
450 frontier.update_iter([(T::minimum(), parts_diff)]);
451 let part_frontiers = vec![Antichain::from_elem(T::minimum()); parts];
452 let frontier_entry = (frontier, part_frontiers);
453
454 Self {
455 write_frontier: frontier_entry.clone(),
456 input_frontier: frontier_entry.clone(),
457 output_frontier: frontier_entry,
458 }
459 }
460
461 fn all_empty(&self) -> bool {
463 self.write_frontier.0.frontier().is_empty()
464 && self.input_frontier.0.frontier().is_empty()
465 && self.output_frontier.0.frontier().is_empty()
466 }
467
468 fn update_write_frontier(
472 &mut self,
473 shard_id: usize,
474 new_shard_frontier: &Antichain<T>,
475 ) -> Option<Antichain<T>> {
476 Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
477 }
478
479 fn update_input_frontier(
483 &mut self,
484 shard_id: usize,
485 new_shard_frontier: &Antichain<T>,
486 ) -> Option<Antichain<T>> {
487 Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
488 }
489
490 fn update_output_frontier(
494 &mut self,
495 shard_id: usize,
496 new_shard_frontier: &Antichain<T>,
497 ) -> Option<Antichain<T>> {
498 Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
499 }
500
501 fn update_frontier(
503 entry: &mut (MutableAntichain<T>, Vec<Antichain<T>>),
504 shard_id: usize,
505 new_shard_frontier: &Antichain<T>,
506 ) -> Option<Antichain<T>> {
507 let (frontier, shard_frontiers) = entry;
508
509 let old_frontier = frontier.frontier().to_owned();
510 let shard_frontier = &mut shard_frontiers[shard_id];
511 frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
512 shard_frontier.join_assign(new_shard_frontier);
513 frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));
514
515 let new_frontier = frontier.frontier();
516
517 if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
518 Some(new_frontier.to_owned())
519 } else {
520 None
521 }
522 }
523}
524
/// In-progress state of a subscribe, merged across all shards.
#[derive(Debug)]
struct PendingSubscribe<T> {
    // Combined frontier contributed to by all shards of the subscribe.
    frontiers: MutableAntichain<T>,
    // Updates held back until the combined frontier advances past their
    // times, or the error that poisoned the subscribe.
    stashed_updates: Result<Vec<UpdateCollection<T>>, String>,
    // Total byte size of the stashed updates, checked against the configured
    // `max_result_size`.
    stashed_result_size: usize,
    // Whether a `DroppedAt` response was already emitted; used to suppress
    // duplicate drop notifications from other shards.
    dropped: bool,
}
538
539impl<T: ComputeControllerTimestamp> PendingSubscribe<T> {
540 fn new(parts: usize) -> Self {
541 let mut frontiers = MutableAntichain::new();
542 #[allow(clippy::as_conversions)]
544 frontiers.update_iter([(T::minimum(), parts as i64)]);
545
546 Self {
547 frontiers,
548 stashed_updates: Ok(Vec::new()),
549 stashed_result_size: 0,
550 dropped: false,
551 }
552 }
553
554 fn stash(
559 &mut self,
560 new_updates: Result<Vec<UpdateCollection<T>>, String>,
561 max_result_size: u64,
562 ) {
563 match (&mut self.stashed_updates, new_updates) {
564 (Err(_), _) => {
565 }
568 (_, Err(text)) => {
569 self.stashed_updates = Err(text);
570 }
571 (Ok(stashed), Ok(new)) => {
572 let new_size: usize = new.iter().map(|coll| coll.byte_len()).sum();
573 self.stashed_result_size += new_size;
574
575 if self.stashed_result_size > max_result_size.cast_into() {
576 self.stashed_updates = Err(format!(
577 "total result exceeds max size of {}",
578 ByteSize::b(max_result_size)
579 ));
580 } else {
581 stashed.extend(new);
582 }
583 }
584 }
585 }
586}
587
588fn merge_peek_responses(
590 resp1: PeekResponse,
591 resp2: PeekResponse,
592 max_result_size: u64,
593) -> PeekResponse {
594 use PeekResponse::*;
595
596 let (resp1, resp2) = match (resp1, resp2) {
598 (Canceled, _) | (_, Canceled) => return Canceled,
599 (Error(e), _) | (_, Error(e)) => return Error(e),
600 resps => resps,
601 };
602
603 let total_byte_len = resp1.inline_byte_len() + resp2.inline_byte_len();
604 if total_byte_len > max_result_size.cast_into() {
605 let err = format!(
608 "total result exceeds max size of {}",
609 ByteSize::b(max_result_size)
610 );
611 return Error(err);
612 }
613
614 match (resp1, resp2) {
615 (Rows(mut rows1), Rows(rows2)) => {
616 rows1.extend(rows2);
617 Rows(rows1)
618 }
619 (Rows(rows), Stashed(mut stashed)) | (Stashed(mut stashed), Rows(rows)) => {
620 stashed.inline_rows.extend(rows);
621 Stashed(stashed)
622 }
623 (Stashed(stashed1), Stashed(stashed2)) => {
624 let StashedPeekResponse {
627 num_rows_batches: num_rows_batches1,
628 encoded_size_bytes: encoded_size_bytes1,
629 relation_desc: relation_desc1,
630 shard_id: shard_id1,
631 batches: mut batches1,
632 inline_rows: mut inline_rows1,
633 } = *stashed1;
634 let StashedPeekResponse {
635 num_rows_batches: num_rows_batches2,
636 encoded_size_bytes: encoded_size_bytes2,
637 relation_desc: relation_desc2,
638 shard_id: shard_id2,
639 batches: mut batches2,
640 inline_rows: inline_rows2,
641 } = *stashed2;
642
643 if shard_id1 != shard_id2 {
644 soft_panic_or_log!(
645 "shard IDs of stashed responses do not match: \
646 {shard_id1} != {shard_id2}"
647 );
648 return Error("internal error".into());
649 }
650 if relation_desc1 != relation_desc2 {
651 soft_panic_or_log!(
652 "relation descs of stashed responses do not match: \
653 {relation_desc1:?} != {relation_desc2:?}"
654 );
655 return Error("internal error".into());
656 }
657
658 batches1.append(&mut batches2);
659 inline_rows1.extend(inline_rows2);
660
661 Stashed(Box::new(StashedPeekResponse {
662 num_rows_batches: num_rows_batches1 + num_rows_batches2,
663 encoded_size_bytes: encoded_size_bytes1 + encoded_size_bytes2,
664 relation_desc: relation_desc1,
665 shard_id: shard_id1,
666 batches: batches1,
667 inline_rows: inline_rows1,
668 }))
669 }
670 _ => unreachable!("handled above"),
671 }
672}