1use std::collections::{BTreeMap, BTreeSet};
13use std::mem;
14
15use async_trait::async_trait;
16use bytesize::ByteSize;
17use differential_dataflow::lattice::Lattice;
18use mz_expr::row::RowCollection;
19use mz_ore::cast::CastInto;
20use mz_ore::soft_panic_or_log;
21use mz_ore::tracing::OpenTelemetryContext;
22use mz_repr::{GlobalId, Timestamp, UpdateCollection};
23use mz_service::client::{GenericClient, Partitionable, PartitionedState};
24use timely::PartialOrder;
25use timely::progress::frontier::{Antichain, MutableAntichain};
26use uuid::Uuid;
27
28use crate::protocol::command::ComputeCommand;
29use crate::protocol::response::{
30 ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, StashedPeekResponse,
31 SubscribeBatch, SubscribeResponse,
32};
33
/// A client to a compute server: anything that can exchange `ComputeCommand`s
/// for `ComputeResponse`s.
pub trait ComputeClient: GenericClient<ComputeCommand, ComputeResponse> {}

// Blanket impl: every `GenericClient` speaking the compute protocol is a
// `ComputeClient`; the trait exists only as a convenient alias bound.
impl<C> ComputeClient for C where C: GenericClient<ComputeCommand, ComputeResponse> {}
38
// Allow boxed trait objects to be used wherever a `GenericClient` is expected,
// by forwarding both methods to the boxed client.
#[async_trait]
impl GenericClient<ComputeCommand, ComputeResponse> for Box<dyn ComputeClient> {
    async fn send(&mut self, cmd: ComputeCommand) -> Result<(), anyhow::Error> {
        // Plain delegation to the inner client.
        (**self).send(cmd).await
    }

    async fn recv(&mut self) -> Result<Option<ComputeResponse>, anyhow::Error> {
        // Plain delegation to the inner client.
        // NOTE(review): cancel safety of this method is inherited from the
        // underlying client's `recv` — confirm against the `GenericClient`
        // contract in `mz_service::client`.
        (**self).recv().await
    }
}
55
/// Client-side state for a partitioned compute client, used to merge the
/// per-shard responses of `parts` compute shards into single logical
/// responses.
#[derive(Debug)]
pub struct PartitionedComputeState {
    /// The number of partitions (shards) this state tracks.
    parts: usize,
    /// The current maximum result size, in bytes.
    ///
    /// Defaults to `u64::MAX` (unlimited) and is updated from
    /// `ComputeCommand::UpdateConfiguration` in `observe_command`.
    max_result_size: u64,
    /// Tracked frontiers per collection; entries are removed once all of a
    /// collection's frontiers have become empty.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers>,
    /// Partially merged peek responses, keyed by peek UUID, paired with the
    /// set of shards that have already reported. Removed once all `parts`
    /// shards have responded.
    peek_responses: BTreeMap<Uuid, (PeekResponse, BTreeSet<usize>)>,
    /// Partially merged `COPY TO` responses, with the same merging scheme as
    /// `peek_responses`.
    copy_to_responses: BTreeMap<GlobalId, (CopyToResponse, BTreeSet<usize>)>,
    /// Tracking state for in-progress subscribes; entries are removed once
    /// the across-shard subscribe frontier becomes empty.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe>,
}
146
147impl Partitionable<ComputeCommand, ComputeResponse> for (ComputeCommand, ComputeResponse) {
148 type PartitionedState = PartitionedComputeState;
149
150 fn new(parts: usize) -> PartitionedComputeState {
151 PartitionedComputeState {
152 parts,
153 max_result_size: u64::MAX,
154 frontiers: BTreeMap::new(),
155 peek_responses: BTreeMap::new(),
156 pending_subscribes: BTreeMap::new(),
157 copy_to_responses: BTreeMap::new(),
158 }
159 }
160}
161
162impl PartitionedComputeState {
163 pub fn observe_command(&mut self, command: &ComputeCommand) {
165 match command {
166 ComputeCommand::UpdateConfiguration(config) => {
167 if let Some(max_result_size) = config.max_result_size {
168 self.max_result_size = max_result_size;
169 }
170 }
171 _ => {
172 }
175 }
176 }
177
178 fn absorb_frontiers(
180 &mut self,
181 shard_id: usize,
182 collection_id: GlobalId,
183 frontiers: FrontiersResponse,
184 ) -> Option<ComputeResponse> {
185 let tracked = self
186 .frontiers
187 .entry(collection_id)
188 .or_insert_with(|| TrackedFrontiers::new(self.parts));
189
190 let write_frontier = frontiers
191 .write_frontier
192 .and_then(|f| tracked.update_write_frontier(shard_id, &f));
193 let input_frontier = frontiers
194 .input_frontier
195 .and_then(|f| tracked.update_input_frontier(shard_id, &f));
196 let output_frontier = frontiers
197 .output_frontier
198 .and_then(|f| tracked.update_output_frontier(shard_id, &f));
199
200 let frontiers = FrontiersResponse {
201 write_frontier,
202 input_frontier,
203 output_frontier,
204 };
205 let result = frontiers
206 .has_updates()
207 .then_some(ComputeResponse::Frontiers(collection_id, frontiers));
208
209 if tracked.all_empty() {
210 self.frontiers.remove(&collection_id);
213 }
214
215 result
216 }
217
218 fn absorb_peek_response(
220 &mut self,
221 shard_id: usize,
222 uuid: Uuid,
223 response: PeekResponse,
224 otel_ctx: OpenTelemetryContext,
225 ) -> Option<ComputeResponse> {
226 let (merged, ready_shards) = self.peek_responses.entry(uuid).or_insert((
227 PeekResponse::Rows(vec![RowCollection::default()]),
228 BTreeSet::new(),
229 ));
230
231 let first = ready_shards.insert(shard_id);
232 assert!(first, "duplicate peek response");
233
234 let resp1 = mem::replace(merged, PeekResponse::Canceled);
235 *merged = merge_peek_responses(resp1, response, self.max_result_size);
236
237 if ready_shards.len() == self.parts {
238 let (response, _) = self.peek_responses.remove(&uuid).unwrap();
239 Some(ComputeResponse::PeekResponse(uuid, response, otel_ctx))
240 } else {
241 None
242 }
243 }
244
245 fn absorb_copy_to_response(
247 &mut self,
248 shard_id: usize,
249 copyto_id: GlobalId,
250 response: CopyToResponse,
251 ) -> Option<ComputeResponse> {
252 use CopyToResponse::*;
253
254 let (merged, ready_shards) = self
255 .copy_to_responses
256 .entry(copyto_id)
257 .or_insert((CopyToResponse::RowCount(0), BTreeSet::new()));
258
259 let first = ready_shards.insert(shard_id);
260 assert!(first, "duplicate copy-to response");
261
262 let resp1 = mem::replace(merged, Dropped);
263 *merged = match (resp1, response) {
264 (Dropped, _) | (_, Dropped) => Dropped,
265 (Error(e), _) | (_, Error(e)) => Error(e),
266 (RowCount(r1), RowCount(r2)) => RowCount(r1 + r2),
267 };
268
269 if ready_shards.len() == self.parts {
270 let (response, _) = self.copy_to_responses.remove(©to_id).unwrap();
271 Some(ComputeResponse::CopyToResponse(copyto_id, response))
272 } else {
273 None
274 }
275 }
276
277 fn absorb_subscribe_response(
279 &mut self,
280 subscribe_id: GlobalId,
281 response: SubscribeResponse,
282 ) -> Option<ComputeResponse> {
283 let tracked = self
284 .pending_subscribes
285 .entry(subscribe_id)
286 .or_insert_with(|| PendingSubscribe::new(self.parts));
287
288 let emit_response = match response {
289 SubscribeResponse::Batch(batch) => {
290 let frontiers = &mut tracked.frontiers;
291 let old_frontier = frontiers.frontier().to_owned();
292 frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
293 frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
294 let new_frontier = frontiers.frontier().to_owned();
295
296 tracked.stash(batch.updates, self.max_result_size);
297
298 if old_frontier != new_frontier && !tracked.dropped {
302 let updates = match &mut tracked.stashed_updates {
303 Ok(stashed_updates) => {
304 let mut ship = vec![];
306 let mut keep = vec![];
307 for collection in stashed_updates.drain(..) {
308 let partition_point = collection
309 .times()
310 .partition_point(|t| !new_frontier.less_equal(t));
311 let (ship_coll, keep_coll) = collection.split_at(partition_point);
312 if ship_coll.len() > 0 {
313 ship.push(ship_coll);
314 }
315 if keep_coll.len() > 0 {
316 keep.push(keep_coll);
317 }
318 }
319 tracked.stashed_result_size = keep.iter().map(|c| c.byte_len()).sum();
320 tracked.stashed_updates = Ok(keep);
321 Ok(ship)
322 }
323 Err(text) => Err(text.clone()),
324 };
325 Some(ComputeResponse::SubscribeResponse(
326 subscribe_id,
327 SubscribeResponse::Batch(SubscribeBatch {
328 lower: old_frontier,
329 upper: new_frontier,
330 updates,
331 }),
332 ))
333 } else {
334 None
335 }
336 }
337 SubscribeResponse::DroppedAt(frontier) => {
338 tracked
339 .frontiers
340 .update_iter(frontier.iter().map(|t| (t.clone(), -1)));
341
342 if tracked.dropped {
343 None
344 } else {
345 tracked.dropped = true;
346 Some(ComputeResponse::SubscribeResponse(
347 subscribe_id,
348 SubscribeResponse::DroppedAt(frontier),
349 ))
350 }
351 }
352 };
353
354 if tracked.frontiers.frontier().is_empty() {
355 self.pending_subscribes.remove(&subscribe_id);
358 }
359
360 emit_response
361 }
362}
363
364impl PartitionedState<ComputeCommand, ComputeResponse> for PartitionedComputeState {
365 fn split_command(&mut self, command: ComputeCommand) -> Vec<Option<ComputeCommand>> {
366 self.observe_command(&command);
367
368 match command {
372 command @ ComputeCommand::Hello { .. }
373 | command @ ComputeCommand::UpdateConfiguration(_) => {
374 vec![Some(command); self.parts]
375 }
376 command => {
377 let mut r = vec![None; self.parts];
378 r[0] = Some(command);
379 r
380 }
381 }
382 }
383
384 fn absorb_response(
385 &mut self,
386 shard_id: usize,
387 message: ComputeResponse,
388 ) -> Option<Result<ComputeResponse, anyhow::Error>> {
389 let response = match message {
390 ComputeResponse::Frontiers(id, frontiers) => {
391 self.absorb_frontiers(shard_id, id, frontiers)
392 }
393 ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
394 self.absorb_peek_response(shard_id, uuid, response, otel_ctx)
395 }
396 ComputeResponse::SubscribeResponse(id, response) => {
397 self.absorb_subscribe_response(id, response)
398 }
399 ComputeResponse::CopyToResponse(id, response) => {
400 self.absorb_copy_to_response(shard_id, id, response)
401 }
402 response @ ComputeResponse::Status(_) => {
403 Some(response)
405 }
406 };
407
408 response.map(Ok)
409 }
410}
411
/// Frontier tracking for a single collection across all shards.
///
/// Each field pairs the overall across-shard frontier, maintained as a
/// `MutableAntichain`, with the last reported frontier of each individual
/// shard.
#[derive(Debug)]
struct TrackedFrontiers {
    /// The tracked write frontier.
    write_frontier: (MutableAntichain<Timestamp>, Vec<Antichain<Timestamp>>),
    /// The tracked input frontier.
    input_frontier: (MutableAntichain<Timestamp>, Vec<Antichain<Timestamp>>),
    /// The tracked output frontier.
    output_frontier: (MutableAntichain<Timestamp>, Vec<Antichain<Timestamp>>),
}
425
426impl TrackedFrontiers {
427 fn new(parts: usize) -> Self {
429 #[allow(clippy::as_conversions)]
431 let parts_diff = parts as i64;
432
433 let mut frontier = MutableAntichain::new();
434 frontier.update_iter([(Timestamp::MIN, parts_diff)]);
435 let part_frontiers = vec![Antichain::from_elem(Timestamp::MIN); parts];
436 let frontier_entry = (frontier, part_frontiers);
437
438 Self {
439 write_frontier: frontier_entry.clone(),
440 input_frontier: frontier_entry.clone(),
441 output_frontier: frontier_entry,
442 }
443 }
444
445 fn all_empty(&self) -> bool {
447 self.write_frontier.0.frontier().is_empty()
448 && self.input_frontier.0.frontier().is_empty()
449 && self.output_frontier.0.frontier().is_empty()
450 }
451
452 fn update_write_frontier(
456 &mut self,
457 shard_id: usize,
458 new_shard_frontier: &Antichain<Timestamp>,
459 ) -> Option<Antichain<Timestamp>> {
460 Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
461 }
462
463 fn update_input_frontier(
467 &mut self,
468 shard_id: usize,
469 new_shard_frontier: &Antichain<Timestamp>,
470 ) -> Option<Antichain<Timestamp>> {
471 Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
472 }
473
474 fn update_output_frontier(
478 &mut self,
479 shard_id: usize,
480 new_shard_frontier: &Antichain<Timestamp>,
481 ) -> Option<Antichain<Timestamp>> {
482 Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
483 }
484
485 fn update_frontier(
487 entry: &mut (MutableAntichain<Timestamp>, Vec<Antichain<Timestamp>>),
488 shard_id: usize,
489 new_shard_frontier: &Antichain<Timestamp>,
490 ) -> Option<Antichain<Timestamp>> {
491 let (frontier, shard_frontiers) = entry;
492
493 let old_frontier = frontier.frontier().to_owned();
494 let shard_frontier = &mut shard_frontiers[shard_id];
495 frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
496 shard_frontier.join_assign(new_shard_frontier);
497 frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));
498
499 let new_frontier = frontier.frontier();
500
501 if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
502 Some(new_frontier.to_owned())
503 } else {
504 None
505 }
506 }
507}
508
/// Tracking state for an in-progress subscribe, across all shards.
#[derive(Debug)]
struct PendingSubscribe {
    /// The subscribe frontiers reported by the shards, as a multiset.
    frontiers: MutableAntichain<Timestamp>,
    /// Updates received but not yet shipped, or the error that poisoned this
    /// subscribe (once `Err`, it stays `Err`).
    stashed_updates: Result<Vec<UpdateCollection>, String>,
    /// Total byte size of the currently stashed updates; used to enforce the
    /// configured `max_result_size`.
    stashed_result_size: usize,
    /// Whether a `DroppedAt` response has already been forwarded for this
    /// subscribe (it must be forwarded at most once).
    dropped: bool,
}
522
523impl PendingSubscribe {
524 fn new(parts: usize) -> Self {
525 let mut frontiers = MutableAntichain::new();
526 #[allow(clippy::as_conversions)]
528 frontiers.update_iter([(Timestamp::MIN, parts as i64)]);
529
530 Self {
531 frontiers,
532 stashed_updates: Ok(Vec::new()),
533 stashed_result_size: 0,
534 dropped: false,
535 }
536 }
537
538 fn stash(&mut self, new_updates: Result<Vec<UpdateCollection>, String>, max_result_size: u64) {
543 match (&mut self.stashed_updates, new_updates) {
544 (Err(_), _) => {
545 }
548 (_, Err(text)) => {
549 self.stashed_updates = Err(text);
550 }
551 (Ok(stashed), Ok(new)) => {
552 let new_size: usize = new.iter().map(|coll| coll.byte_len()).sum();
553 self.stashed_result_size += new_size;
554
555 if self.stashed_result_size > max_result_size.cast_into() {
556 self.stashed_updates = Err(format!(
557 "total result exceeds max size of {}",
558 ByteSize::b(max_result_size)
559 ));
560 } else {
561 stashed.extend(new);
562 }
563 }
564 }
565 }
566}
567
/// Merges two peek responses into one.
///
/// `Canceled` dominates everything, errors dominate data responses, and two
/// data responses (rows and/or stashed) are combined. Returns an error
/// response when the combined inline size exceeds `max_result_size`.
fn merge_peek_responses(
    resp1: PeekResponse,
    resp2: PeekResponse,
    max_result_size: u64,
) -> PeekResponse {
    use PeekResponse::*;

    // Cancellations and errors short-circuit the merge.
    let (resp1, resp2) = match (resp1, resp2) {
        (Canceled, _) | (_, Canceled) => return Canceled,
        (Error(e), _) | (_, Error(e)) => return Error(e),
        resps => resps,
    };

    // Enforce the size limit on the inline portions only.
    // NOTE(review): stashed batch bytes are not counted against the limit
    // here — presumably they are accounted for elsewhere; confirm.
    let total_byte_len = resp1.inline_byte_len() + resp2.inline_byte_len();
    if total_byte_len > max_result_size.cast_into() {
        let err = format!(
            "total result exceeds max size of {}",
            ByteSize::b(max_result_size)
        );
        return Error(err);
    }

    match (resp1, resp2) {
        (Rows(mut rows1), Rows(rows2)) => {
            rows1.extend(rows2);
            Rows(rows1)
        }
        // Rows merged into a stashed response become part of its inline rows.
        (Rows(rows), Stashed(mut stashed)) | (Stashed(mut stashed), Rows(rows)) => {
            stashed.inline_rows.extend(rows);
            Stashed(stashed)
        }
        (Stashed(stashed1), Stashed(stashed2)) => {
            let StashedPeekResponse {
                num_rows_batches: num_rows_batches1,
                encoded_size_bytes: encoded_size_bytes1,
                relation_desc: relation_desc1,
                shard_id: shard_id1,
                batches: mut batches1,
                inline_rows: mut inline_rows1,
            } = *stashed1;
            let StashedPeekResponse {
                num_rows_batches: num_rows_batches2,
                encoded_size_bytes: encoded_size_bytes2,
                relation_desc: relation_desc2,
                shard_id: shard_id2,
                batches: mut batches2,
                inline_rows: inline_rows2,
            } = *stashed2;

            // Two stashed responses can only be combined if they refer to the
            // same stash shard and the same relation schema; a mismatch
            // indicates an internal invariant violation.
            if shard_id1 != shard_id2 {
                soft_panic_or_log!(
                    "shard IDs of stashed responses do not match: \
                     {shard_id1} != {shard_id2}"
                );
                return Error("internal error".into());
            }
            if relation_desc1 != relation_desc2 {
                soft_panic_or_log!(
                    "relation descs of stashed responses do not match: \
                     {relation_desc1:?} != {relation_desc2:?}"
                );
                return Error("internal error".into());
            }

            batches1.append(&mut batches2);
            inline_rows1.extend(inline_rows2);

            // Counts and sizes are additive across the two responses.
            Stashed(Box::new(StashedPeekResponse {
                num_rows_batches: num_rows_batches1 + num_rows_batches2,
                encoded_size_bytes: encoded_size_bytes1 + encoded_size_bytes2,
                relation_desc: relation_desc1,
                shard_id: shard_id1,
                batches: batches1,
                inline_rows: inline_rows1,
            }))
        }
        // Canceled/Error pairs were returned early above.
        _ => unreachable!("handled above"),
    }
}