//! Client and server implementations for the compute protocol.

#![allow(clippy::as_conversions, clippy::clone_on_ref_ptr)]

use std::collections::{BTreeMap, BTreeSet};
use std::mem;

use async_trait::async_trait;
use bytesize::ByteSize;
use differential_dataflow::consolidation::consolidate_updates;
use differential_dataflow::lattice::Lattice;
use mz_expr::row::RowCollection;
use mz_ore::cast::CastInto;
use mz_ore::soft_panic_or_log;
use mz_ore::tracing::OpenTelemetryContext;
use mz_repr::{Diff, GlobalId, Row};
use mz_service::client::{GenericClient, Partitionable, PartitionedState};
use mz_service::grpc::{GrpcClient, GrpcServer, ProtoServiceTypes, ResponseStream};
use timely::PartialOrder;
use timely::progress::frontier::{Antichain, MutableAntichain};
use tonic::{Request, Status, Streaming};
use uuid::Uuid;

use crate::controller::ComputeControllerTimestamp;
use crate::metrics::ReplicaMetrics;
use crate::protocol::command::{ComputeCommand, ProtoComputeCommand};
use crate::protocol::response::{
    ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, ProtoComputeResponse,
    StashedPeekResponse, SubscribeBatch, SubscribeResponse,
};
use crate::service::proto_compute_server::ProtoCompute;

include!(concat!(env!("OUT_DIR"), "/mz_compute_client.service.rs"));

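/// A client to a compute server.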
pub trait ComputeClient<T = mz_repr::Timestamp>:
    GenericClient<ComputeCommand<T>, ComputeResponse<T>>
{
}

impl<C, T> ComputeClient<T> for C where C: GenericClient<ComputeCommand<T>, ComputeResponse<T>> {}

#[async_trait]
impl<T: Send> GenericClient<ComputeCommand<T>, ComputeResponse<T>> for Box<dyn ComputeClient<T>> {
    async fn send(&mut self, cmd: ComputeCommand<T>) -> Result<(), anyhow::Error> {
        (**self).send(cmd).await
    }

    async fn recv(&mut self) -> Result<Option<ComputeResponse<T>>, anyhow::Error> {
        (**self).recv().await
    }
}

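/// The protobuf types used by the compute gRPC service.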
#[derive(Debug, Clone)]
pub enum ComputeProtoServiceTypes {}

impl ProtoServiceTypes for ComputeProtoServiceTypes {
    type PC = ProtoComputeCommand;
    type PR = ProtoComputeResponse;
    type STATS = ReplicaMetrics;
    const URL: &'static str = "/mz_compute_client.service.ProtoCompute/CommandResponseStream";
}

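/// A gRPC client to a compute server.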
pub type ComputeGrpcClient = GrpcClient<ComputeProtoServiceTypes>;

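/// gRPC server implementation that forwards the bidirectional command/response
/// stream between the network and a local [`ComputeClient`].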
#[async_trait]
impl<F, G> ProtoCompute for GrpcServer<F>
where
    F: Fn() -> G + Send + Sync + 'static,
    G: ComputeClient + 'static,
{
    type CommandResponseStreamStream = ResponseStream<ProtoComputeResponse>;

    async fn command_response_stream(
        &self,
        request: Request<Streaming<ProtoComputeCommand>>,
    ) -> Result<tonic::Response<Self::CommandResponseStreamStream>, Status> {
        self.forward_bidi_stream(request).await
    }
}

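/// State for a client that partitions commands and responses between multiple
/// compute replica shards.
///
/// Merges the per-shard responses (frontiers, peeks, subscribes, `COPY TO`s)
/// back into single responses for the caller.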
#[derive(Debug)]
pub struct PartitionedComputeState<T> {
    /// Number of partitions (shards) tracked by this state.
    parts: usize,
    /// The currently configured maximum result size, in bytes.
    max_result_size: u64,
    /// Per-collection frontier tracking, used to merge `Frontiers` responses.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers<T>>,
    /// Pending peek responses and the shards that have already replied.
    peek_responses: BTreeMap<Uuid, (PeekResponse, BTreeSet<usize>)>,
    /// Pending `COPY TO` responses and the shards that have already replied.
    copy_to_responses: BTreeMap<GlobalId, (CopyToResponse, BTreeSet<usize>)>,
    /// State for in-progress subscribes, keyed by subscribe ID.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe<T>>,
}

impl<T> Partitionable<ComputeCommand<T>, ComputeResponse<T>>
    for (ComputeCommand<T>, ComputeResponse<T>)
where
    T: ComputeControllerTimestamp,
{
    type PartitionedState = PartitionedComputeState<T>;

    fn new(parts: usize) -> PartitionedComputeState<T> {
        PartitionedComputeState {
            parts,
            max_result_size: u64::MAX,
            frontiers: BTreeMap::new(),
            peek_responses: BTreeMap::new(),
            pending_subscribes: BTreeMap::new(),
            copy_to_responses: BTreeMap::new(),
        }
    }
}

impl<T> PartitionedComputeState<T>
where
    T: ComputeControllerTimestamp,
{
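    /// Observes a command about to be sent to the replica shards, updating the
    /// tracked configuration (currently only `max_result_size`).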
    pub fn observe_command(&mut self, command: &ComputeCommand<T>) {
        match command {
            ComputeCommand::UpdateConfiguration(config) => {
                if let Some(max_result_size) = config.max_result_size {
                    self.max_result_size = max_result_size;
                }
            }
            _ => {
                // No bookkeeping required for other commands.
            }
        }
    }

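    /// Absorbs a `Frontiers` response from one shard, returning a merged
    /// response if any of the combined frontiers advanced.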
    fn absorb_frontiers(
        &mut self,
        shard_id: usize,
        collection_id: GlobalId,
        frontiers: FrontiersResponse<T>,
    ) -> Option<ComputeResponse<T>> {
        let tracked = self
            .frontiers
            .entry(collection_id)
            .or_insert_with(|| TrackedFrontiers::new(self.parts));

        let write_frontier = frontiers
            .write_frontier
            .and_then(|f| tracked.update_write_frontier(shard_id, &f));
        let input_frontier = frontiers
            .input_frontier
            .and_then(|f| tracked.update_input_frontier(shard_id, &f));
        let output_frontier = frontiers
            .output_frontier
            .and_then(|f| tracked.update_output_frontier(shard_id, &f));

        let frontiers = FrontiersResponse {
            write_frontier,
            input_frontier,
            output_frontier,
        };
        let result = frontiers
            .has_updates()
            .then_some(ComputeResponse::Frontiers(collection_id, frontiers));

        if tracked.all_empty() {
            // All tracked frontiers are empty; the collection is complete and
            // its tracking state can be dropped.
            self.frontiers.remove(&collection_id);
        }

        result
    }

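    /// Absorbs a peek response from one shard, merging it into the pending
    /// response. Returns the merged response once all shards have reported.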
    fn absorb_peek_response(
        &mut self,
        shard_id: usize,
        uuid: Uuid,
        response: PeekResponse,
        otel_ctx: OpenTelemetryContext,
    ) -> Option<ComputeResponse<T>> {
        let (merged, ready_shards) = self.peek_responses.entry(uuid).or_insert((
            PeekResponse::Rows(RowCollection::default()),
            BTreeSet::new(),
        ));

        let first = ready_shards.insert(shard_id);
        assert!(first, "duplicate peek response");

        // Temporarily swap in a placeholder so the previous and new responses
        // can be merged by value.
        let resp1 = mem::replace(merged, PeekResponse::Canceled);
        *merged = merge_peek_responses(resp1, response, self.max_result_size);

        if ready_shards.len() == self.parts {
            let (response, _) = self.peek_responses.remove(&uuid).unwrap();
            Some(ComputeResponse::PeekResponse(uuid, response, otel_ctx))
        } else {
            None
        }
    }

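    /// Absorbs a `COPY TO` response from one shard, summing row counts and
    /// propagating errors and drops. Returns the merged response once all
    /// shards have reported.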
    fn absorb_copy_to_response(
        &mut self,
        shard_id: usize,
        copyto_id: GlobalId,
        response: CopyToResponse,
    ) -> Option<ComputeResponse<T>> {
        use CopyToResponse::*;

        let (merged, ready_shards) = self
            .copy_to_responses
            .entry(copyto_id)
            .or_insert((CopyToResponse::RowCount(0), BTreeSet::new()));

        let first = ready_shards.insert(shard_id);
        assert!(first, "duplicate copy-to response");

        // Temporarily swap in a placeholder so the previous and new responses
        // can be merged by value.
        let resp1 = mem::replace(merged, Dropped);
        *merged = match (resp1, response) {
            (Dropped, _) | (_, Dropped) => Dropped,
            (Error(e), _) | (_, Error(e)) => Error(e),
            (RowCount(r1), RowCount(r2)) => RowCount(r1 + r2),
        };

        if ready_shards.len() == self.parts {
            let (response, _) = self.copy_to_responses.remove(&copyto_id).unwrap();
            Some(ComputeResponse::CopyToResponse(copyto_id, response))
        } else {
            None
        }
    }

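    /// Absorbs a subscribe response from one shard.
    ///
    /// Batches are stashed until the combined frontier across shards advances;
    /// a `DroppedAt` response is forwarded only for the first shard that
    /// reports the drop.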
    fn absorb_subscribe_response(
        &mut self,
        subscribe_id: GlobalId,
        response: SubscribeResponse<T>,
    ) -> Option<ComputeResponse<T>> {
        let tracked = self
            .pending_subscribes
            .entry(subscribe_id)
            .or_insert_with(|| PendingSubscribe::new(self.parts));

        let emit_response = match response {
            SubscribeResponse::Batch(batch) => {
                // Swap this shard's contribution to the combined frontier from
                // the batch's lower bound to its upper bound.
                let frontiers = &mut tracked.frontiers;
                let old_frontier = frontiers.frontier().to_owned();
                frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
                frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
                let new_frontier = frontiers.frontier().to_owned();

                tracked.stash(batch.updates, self.max_result_size);

                if old_frontier != new_frontier && !tracked.dropped {
                    let updates = match &mut tracked.stashed_updates {
                        Ok(stashed_updates) => {
                            consolidate_updates(stashed_updates);

                            // Ship the updates that are not beyond the new
                            // frontier; keep the rest stashed.
                            let mut ship = Vec::new();
                            let mut keep = Vec::new();
                            for (time, data, diff) in stashed_updates.drain(..) {
                                if new_frontier.less_equal(&time) {
                                    keep.push((time, data, diff));
                                } else {
                                    ship.push((time, data, diff));
                                }
                            }
                            tracked.stashed_updates = Ok(keep);
                            Ok(ship)
                        }
                        Err(text) => Err(text.clone()),
                    };
                    Some(ComputeResponse::SubscribeResponse(
                        subscribe_id,
                        SubscribeResponse::Batch(SubscribeBatch {
                            lower: old_frontier,
                            upper: new_frontier,
                            updates,
                        }),
                    ))
                } else {
                    None
                }
            }
            SubscribeResponse::DroppedAt(frontier) => {
                tracked
                    .frontiers
                    .update_iter(frontier.iter().map(|t| (t.clone(), -1)));

                // Only forward the drop for the first shard that reports it.
                if tracked.dropped {
                    None
                } else {
                    tracked.dropped = true;
                    Some(ComputeResponse::SubscribeResponse(
                        subscribe_id,
                        SubscribeResponse::DroppedAt(frontier),
                    ))
                }
            }
        };

        if tracked.frontiers.frontier().is_empty() {
            // All shards are done with this subscribe; drop its tracking state.
            self.pending_subscribes.remove(&subscribe_id);
        }

        emit_response
    }
}

impl<T> PartitionedState<ComputeCommand<T>, ComputeResponse<T>> for PartitionedComputeState<T>
where
    T: ComputeControllerTimestamp,
{
    fn split_command(&mut self, command: ComputeCommand<T>) -> Vec<Option<ComputeCommand<T>>> {
        self.observe_command(&command);

        match command {
            // Commands that every shard must observe are broadcast to all parts.
            command @ ComputeCommand::Hello { .. }
            | command @ ComputeCommand::UpdateConfiguration(_) => {
                vec![Some(command); self.parts]
            }
            // All other commands are sent to the first shard only.
            command => {
                let mut r = vec![None; self.parts];
                r[0] = Some(command);
                r
            }
        }
    }

    fn absorb_response(
        &mut self,
        shard_id: usize,
        message: ComputeResponse<T>,
    ) -> Option<Result<ComputeResponse<T>, anyhow::Error>> {
        let response = match message {
            ComputeResponse::Frontiers(id, frontiers) => {
                self.absorb_frontiers(shard_id, id, frontiers)
            }
            ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
                self.absorb_peek_response(shard_id, uuid, response, otel_ctx)
            }
            ComputeResponse::SubscribeResponse(id, response) => {
                self.absorb_subscribe_response(id, response)
            }
            ComputeResponse::CopyToResponse(id, response) => {
                self.absorb_copy_to_response(shard_id, id, response)
            }
            response @ ComputeResponse::Status(_) => {
                // Status responses are not merged; pass them through as-is.
                Some(response)
            }
        };

        response.map(Ok)
    }
}

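/// Frontier tracking for a single collection, across all replica shards.
///
/// Each entry pairs the combined `MutableAntichain` with the last reported
/// frontier of every individual shard.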
#[derive(Debug)]
struct TrackedFrontiers<T> {
    /// The combined write frontier and the per-shard write frontiers.
    write_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The combined input frontier and the per-shard input frontiers.
    input_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The combined output frontier and the per-shard output frontiers.
    output_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
}

impl<T> TrackedFrontiers<T>
where
    T: timely::progress::Timestamp + Lattice,
{
    fn new(parts: usize) -> Self {
        // Conversion to `i64` for the antichain update; the part count is
        // small in practice.
        #[allow(clippy::as_conversions)]
        let parts_diff = parts as i64;

        let mut frontier = MutableAntichain::new();
        frontier.update_iter([(T::minimum(), parts_diff)]);
        let part_frontiers = vec![Antichain::from_elem(T::minimum()); parts];
        let frontier_entry = (frontier, part_frontiers);

        Self {
            write_frontier: frontier_entry.clone(),
            input_frontier: frontier_entry.clone(),
            output_frontier: frontier_entry,
        }
    }

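    /// Returns whether all tracked frontiers are empty.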
    fn all_empty(&self) -> bool {
        self.write_frontier.0.frontier().is_empty()
            && self.input_frontier.0.frontier().is_empty()
            && self.output_frontier.0.frontier().is_empty()
    }

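    /// Updates the write frontier tracking with a new frontier from one shard.
    ///
    /// Returns the new combined frontier, if it advanced.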
    fn update_write_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
    }

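    /// Updates the input frontier tracking with a new frontier from one shard.
    ///
    /// Returns the new combined frontier, if it advanced.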
    fn update_input_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
    }

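    /// Updates the output frontier tracking with a new frontier from one shard.
    ///
    /// Returns the new combined frontier, if it advanced.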
    fn update_output_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
    }

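    /// Replaces one shard's contribution to a tracked frontier pair, returning
    /// the new combined frontier if it advanced.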
    fn update_frontier(
        entry: &mut (MutableAntichain<T>, Vec<Antichain<T>>),
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        let (frontier, shard_frontiers) = entry;

        let old_frontier = frontier.frontier().to_owned();

        // Replace the shard's previous contribution to the combined frontier
        // with the join of its previous and new frontiers.
        let shard_frontier = &mut shard_frontiers[shard_id];
        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
        shard_frontier.join_assign(new_shard_frontier);
        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));

        let new_frontier = frontier.frontier();

        if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
            Some(new_frontier.to_owned())
        } else {
            None
        }
    }
}

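/// State tracking an in-progress subscribe, across all replica shards.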
#[derive(Debug)]
struct PendingSubscribe<T> {
    /// The combined frontier reported by the shards.
    frontiers: MutableAntichain<T>,
    /// Updates not yet emitted, or an error if one was reported.
    stashed_updates: Result<Vec<(T, Row, Diff)>, String>,
    /// Total size, in bytes, of the currently stashed rows.
    stashed_result_size: usize,
    /// Whether a `DroppedAt` response has already been emitted for this subscribe.
    dropped: bool,
}

impl<T: ComputeControllerTimestamp> PendingSubscribe<T> {
    fn new(parts: usize) -> Self {
        let mut frontiers = MutableAntichain::new();
        // Conversion to `i64` for the antichain update; the part count is
        // small in practice.
        #[allow(clippy::as_conversions)]
        frontiers.update_iter([(T::minimum(), parts as i64)]);

        Self {
            frontiers,
            stashed_updates: Ok(Vec::new()),
            stashed_result_size: 0,
            dropped: false,
        }
    }

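    /// Stashes updates for later emission, recording their total size and
    /// replacing the stash with an error if it exceeds `max_result_size`.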
    fn stash(&mut self, new_updates: Result<Vec<(T, Row, Diff)>, String>, max_result_size: u64) {
        match (&mut self.stashed_updates, new_updates) {
            (Err(_), _) => {
                // Already in an error state; nothing to do.
            }
            (_, Err(text)) => {
                self.stashed_updates = Err(text);
            }
            (Ok(stashed), Ok(new)) => {
                let new_size: usize = new.iter().map(|(_, row, _)| row.byte_len()).sum();
                self.stashed_result_size += new_size;

                if self.stashed_result_size > max_result_size.cast_into() {
                    self.stashed_updates = Err(format!(
                        "total result exceeds max size of {}",
                        ByteSize::b(max_result_size)
                    ));
                } else {
                    stashed.extend(new);
                }
            }
        }
    }
}

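/// Merges two peek responses under a maximum result size.
///
/// Cancellations and errors dominate; otherwise rows and stashed batches are
/// combined, or an error is returned if the inline result grows too large.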
fn merge_peek_responses(
    resp1: PeekResponse,
    resp2: PeekResponse,
    max_result_size: u64,
) -> PeekResponse {
    use PeekResponse::*;

    // Cancellations and errors take precedence over any other response.
    let (resp1, resp2) = match (resp1, resp2) {
        (Canceled, _) | (_, Canceled) => return Canceled,
        (Error(e), _) | (_, Error(e)) => return Error(e),
        resps => resps,
    };

    // Refuse to merge if the combined inline results would exceed the limit.
    let total_byte_len = resp1.inline_byte_len() + resp2.inline_byte_len();
    if total_byte_len > max_result_size.cast_into() {
        let err = format!(
            "total result exceeds max size of {}",
            ByteSize::b(max_result_size)
        );
        return Error(err);
    }

    match (resp1, resp2) {
        (Rows(mut rows1), Rows(rows2)) => {
            rows1.merge(&rows2);
            Rows(rows1)
        }
        (Rows(rows), Stashed(mut stashed)) | (Stashed(mut stashed), Rows(rows)) => {
            stashed.inline_rows.merge(&rows);
            Stashed(stashed)
        }
        (Stashed(stashed1), Stashed(stashed2)) => {
            let StashedPeekResponse {
                num_rows_batches: num_rows_batches1,
                encoded_size_bytes: encoded_size_bytes1,
                relation_desc: relation_desc1,
                shard_id: shard_id1,
                batches: mut batches1,
                inline_rows: mut inline_rows1,
            } = *stashed1;
            let StashedPeekResponse {
                num_rows_batches: num_rows_batches2,
                encoded_size_bytes: encoded_size_bytes2,
                relation_desc: relation_desc2,
                shard_id: shard_id2,
                batches: mut batches2,
                inline_rows: inline_rows2,
            } = *stashed2;

            if shard_id1 != shard_id2 {
                soft_panic_or_log!(
                    "shard IDs of stashed responses do not match: \
                     {shard_id1} != {shard_id2}"
                );
                return Error("internal error".into());
            }
            if relation_desc1 != relation_desc2 {
                soft_panic_or_log!(
                    "relation descs of stashed responses do not match: \
                     {relation_desc1:?} != {relation_desc2:?}"
                );
                return Error("internal error".into());
            }

            batches1.append(&mut batches2);
            inline_rows1.merge(&inline_rows2);

            Stashed(Box::new(StashedPeekResponse {
                num_rows_batches: num_rows_batches1 + num_rows_batches2,
                encoded_size_bytes: encoded_size_bytes1 + encoded_size_bytes2,
                relation_desc: relation_desc1,
                shard_id: shard_id1,
                batches: batches1,
                inline_rows: inline_rows1,
            }))
        }
        _ => unreachable!("handled above"),
    }
}