mz_compute_client/service.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

// Tonic generates code that violates clippy lints.
// TODO: Remove this once tonic does not produce this code anymore.
#![allow(clippy::as_conversions, clippy::clone_on_ref_ptr)]

//! Compute layer client and server.

use std::collections::BTreeMap;

use async_trait::async_trait;
use bytesize::ByteSize;
use differential_dataflow::consolidation::consolidate_updates;
use differential_dataflow::lattice::Lattice;
use mz_expr::row::RowCollection;
use mz_ore::assert_none;
use mz_ore::cast::CastFrom;
use mz_repr::{Diff, GlobalId, Row};
use mz_service::client::{GenericClient, Partitionable, PartitionedState};
use mz_service::grpc::{GrpcClient, GrpcServer, ProtoServiceTypes, ResponseStream};
use timely::PartialOrder;
use timely::progress::frontier::{Antichain, MutableAntichain};
use tonic::{Request, Status, Streaming};
use uuid::Uuid;

use crate::controller::ComputeControllerTimestamp;
use crate::metrics::ReplicaMetrics;
use crate::protocol::command::{ComputeCommand, ProtoComputeCommand};
use crate::protocol::response::{
    ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, ProtoComputeResponse,
    SubscribeBatch, SubscribeResponse,
};
use crate::service::proto_compute_server::ProtoCompute;

include!(concat!(env!("OUT_DIR"), "/mz_compute_client.service.rs"));

/// A client to a compute server.
pub trait ComputeClient<T = mz_repr::Timestamp>:
    GenericClient<ComputeCommand<T>, ComputeResponse<T>>
{
}

impl<C, T> ComputeClient<T> for C where C: GenericClient<ComputeCommand<T>, ComputeResponse<T>> {}

#[async_trait]
impl<T: Send> GenericClient<ComputeCommand<T>, ComputeResponse<T>> for Box<dyn ComputeClient<T>> {
    async fn send(&mut self, cmd: ComputeCommand<T>) -> Result<(), anyhow::Error> {
        (**self).send(cmd).await
    }

    /// # Cancel safety
    ///
    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
    /// statement and some other branch completes first, it is guaranteed that no messages were
    /// received by this client.
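    ///
    /// # Example
    ///
    /// A minimal sketch of why this matters (illustrative only; `shutdown` and
    /// `handle_response` are hypothetical):
    ///
    /// ```ignore
    /// tokio::select! {
    ///     // Safe: if the other branch wins, no response has been consumed.
    ///     resp = client.recv() => handle_response(resp?),
    ///     _ = shutdown.changed() => return Ok(()),
    /// }
    /// ```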
    async fn recv(&mut self) -> Result<Option<ComputeResponse<T>>, anyhow::Error> {
        // `GenericClient::recv` is required to be cancel safe.
        (**self).recv().await
    }
}

/// TODO(database-issues#7533): Add documentation.
#[derive(Debug, Clone)]
pub enum ComputeProtoServiceTypes {}

impl ProtoServiceTypes for ComputeProtoServiceTypes {
    type PC = ProtoComputeCommand;
    type PR = ProtoComputeResponse;
    type STATS = ReplicaMetrics;
    const URL: &'static str = "/mz_compute_client.service.ProtoCompute/CommandResponseStream";
}

/// TODO(database-issues#7533): Add documentation.
pub type ComputeGrpcClient = GrpcClient<ComputeProtoServiceTypes>;

#[async_trait]
impl<F, G> ProtoCompute for GrpcServer<F>
where
    F: Fn() -> G + Send + Sync + 'static,
    G: ComputeClient + 'static,
{
    type CommandResponseStreamStream = ResponseStream<ProtoComputeResponse>;

    async fn command_response_stream(
        &self,
        request: Request<Streaming<ProtoComputeCommand>>,
    ) -> Result<tonic::Response<Self::CommandResponseStreamStream>, Status> {
        self.forward_bidi_stream(request).await
    }
}

/// Maintained state for partitioned compute clients.
///
/// This helper type unifies the responses of multiple partitioned workers in order to present them
/// as the responses of a single worker:
///
///   * It emits `Frontiers` responses reporting the minimum/meet of frontiers reported by the
///     individual workers.
///   * It emits `PeekResponse`s and `SubscribeResponse`s reporting the union of the responses
///     received from the workers.
///
/// In the compute communication stack, this client is instantiated several times:
///
///   * One instance on the controller side, dispatching between cluster processes.
///   * One instance in each cluster process, dispatching between timely worker threads.
///
/// Note that because compute commands, except `CreateTimely` and `UpdateConfiguration`, are only
/// sent to the first process, the cluster-side instances of `PartitionedComputeState` are not
/// guaranteed to see all compute commands. Or more specifically: the instance running inside
/// process 0 sees all commands, whereas the instances running inside the other processes only see
/// `CreateTimely` and `UpdateConfiguration`. The `PartitionedComputeState` implementation must be
/// able to cope with this limited visibility. It does so by performing most of its state
/// management based on observed compute responses rather than commands.
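///
/// # Usage sketch
///
/// A rough, illustrative sketch of how an instance is driven (the `send_to_shard` and
/// `forward_to_controller` helpers are hypothetical):
///
/// ```ignore
/// // `state` is created through `Partitionable::new(parts)`, one instance per replica.
/// for (shard, cmd) in state.split_command(command).into_iter().enumerate() {
///     if let Some(cmd) = cmd {
///         send_to_shard(shard, cmd); // hypothetical transport
///     }
/// }
///
/// // Responses from individual shards are folded back in; `Some(_)` means a
/// // unified response is ready to pass up.
/// if let Some(unified) = state.absorb_response(shard, response) {
///     forward_to_controller(unified?); // hypothetical
/// }
/// ```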
#[derive(Debug)]
pub struct PartitionedComputeState<T> {
    /// Number of partitions the state machine represents.
    parts: usize,
    /// The maximum result size this state machine can return.
    ///
    /// This is updated upon receiving [`ComputeCommand::UpdateConfiguration`]s.
    max_result_size: u64,
    /// Tracked frontiers for indexes and sinks.
    ///
    /// Frontier tracking for a collection is initialized when the first `Frontiers` response
    /// for that collection is received. Frontier tracking is ceased when all shards have reported
    /// advancement to the empty frontier for all frontier kinds.
    ///
    /// The compute protocol requires that shards always emit `Frontiers` responses reporting empty
    /// frontiers for all frontier kinds when a collection is dropped. It further requires that no
    /// further `Frontiers` responses are emitted for a collection after the empty frontiers were
    /// reported. These properties ensure that a) we always cease frontier tracking for collections
    /// that have been dropped and b) frontier tracking for a collection is not re-initialized
    /// after it was ceased.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers<T>>,
    /// Pending responses for a peek; returnable once all are available.
    ///
    /// Tracking of responses for a peek is initialized when the first `PeekResponse` for that peek
    /// is received. Once all shards have provided a `PeekResponse`, a unified peek response is
    /// emitted and the peek tracking state is dropped again.
    ///
    /// The compute protocol requires that exactly one response is emitted for each peek. This
    /// property ensures that a) we can eventually drop the tracking state maintained for a peek
    /// and b) we won't re-initialize tracking for a peek we have already served.
    peek_responses: BTreeMap<Uuid, BTreeMap<usize, PeekResponse>>,
    /// Pending responses for a `COPY TO`; returnable once all are available.
    ///
    /// Tracking of responses for a `COPY TO` is initialized when the first `CopyToResponse` for
    /// that command is received. Once all shards have provided a `CopyToResponse`, a unified copy
    /// response is emitted and the copy_to tracking state is dropped again.
    ///
    /// The compute protocol requires that exactly one response is emitted for each `COPY TO`
    /// command. This property ensures that a) we can eventually drop the tracking state maintained
    /// for a copy and b) we won't re-initialize tracking for a copy we have already served.
    copy_to_responses: BTreeMap<GlobalId, BTreeMap<usize, CopyToResponse>>,
    /// Tracks in-progress `SUBSCRIBE`s, and the stashed rows we are holding back until their
    /// timestamps are complete.
    ///
    /// The updates may be `Err` if any of the batches have reported an error, in which case the
    /// subscribe is permanently borked.
    ///
    /// Tracking of a subscribe is initialized when the first `SubscribeResponse` for that
    /// subscribe is received. Once all shards have emitted an "end-of-subscribe" response the
    /// subscribe tracking state is dropped again.
    ///
    /// The compute protocol requires that, for a subscribe that shuts down, an end-of-subscribe
    /// response is emitted:
    ///
    ///   * Either a `Batch` response reporting advancement to the empty frontier...
    ///   * ... or a `DroppedAt` response reporting that the subscribe was dropped before
    ///     completing.
    ///
    /// The compute protocol further requires that no further `SubscribeResponse`s are emitted for
    /// a subscribe after an end-of-subscribe was reported.
    ///
    /// These two properties ensure that a) once a subscribe has shut down, we can eventually drop
    /// the tracking state maintained for it and b) we won't re-initialize tracking for a subscribe
    /// we have already dropped.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe<T>>,
}

impl<T> Partitionable<ComputeCommand<T>, ComputeResponse<T>>
    for (ComputeCommand<T>, ComputeResponse<T>)
where
    T: ComputeControllerTimestamp,
{
    type PartitionedState = PartitionedComputeState<T>;

    fn new(parts: usize) -> PartitionedComputeState<T> {
        PartitionedComputeState {
            parts,
            max_result_size: u64::MAX,
            frontiers: BTreeMap::new(),
            peek_responses: BTreeMap::new(),
            pending_subscribes: BTreeMap::new(),
            copy_to_responses: BTreeMap::new(),
        }
    }
}

impl<T> PartitionedComputeState<T>
where
    T: ComputeControllerTimestamp,
{
    fn reset(&mut self) {
        let PartitionedComputeState {
            parts: _,
            max_result_size: _,
            frontiers,
            peek_responses,
            pending_subscribes,
            copy_to_responses,
        } = self;
        frontiers.clear();
        peek_responses.clear();
        pending_subscribes.clear();
        copy_to_responses.clear();
    }

    /// Observes commands that move past, and prepares state for responses.
    pub fn observe_command(&mut self, command: &ComputeCommand<T>) {
        match command {
            ComputeCommand::CreateTimely { .. } => self.reset(),
            ComputeCommand::UpdateConfiguration(config) => {
                if let Some(max_result_size) = config.max_result_size {
                    self.max_result_size = max_result_size;
                }
            }
            _ => {
                // We are not guaranteed to observe other compute commands. We
                // must therefore not add any logic here that relies on doing so.
            }
        }
    }
}

impl<T> PartitionedState<ComputeCommand<T>, ComputeResponse<T>> for PartitionedComputeState<T>
where
    T: ComputeControllerTimestamp,
{
    fn split_command(&mut self, command: ComputeCommand<T>) -> Vec<Option<ComputeCommand<T>>> {
        self.observe_command(&command);

        // As specified by the compute protocol:
        //  * Forward `CreateTimely` and `UpdateConfiguration` commands to all shards.
        //  * Forward all other commands to the first shard only.
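        //
        // For example (illustrative only): with `parts = 3`, a peek command is routed as
        // `[Some(cmd), None, None]`, while `UpdateConfiguration` yields three `Some` entries.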
        match command {
            ComputeCommand::CreateTimely { config, epoch } => {
                let timely_cmds = config.split_command(self.parts);

                timely_cmds
                    .into_iter()
                    .map(|config| {
                        Some(ComputeCommand::CreateTimely {
                            config: Box::new(config),
                            epoch,
                        })
                    })
                    .collect()
            }
            command @ ComputeCommand::UpdateConfiguration(_) => {
                vec![Some(command); self.parts]
            }
            command => {
                let mut r = vec![None; self.parts];
                r[0] = Some(command);
                r
            }
        }
    }

    fn absorb_response(
        &mut self,
        shard_id: usize,
        message: ComputeResponse<T>,
    ) -> Option<Result<ComputeResponse<T>, anyhow::Error>> {
        match message {
            ComputeResponse::Frontiers(id, frontiers) => {
                // Initialize frontier tracking state for this collection, if necessary.
                let tracked = self
                    .frontiers
                    .entry(id)
                    .or_insert_with(|| TrackedFrontiers::new(self.parts));

                let write_frontier = frontiers
                    .write_frontier
                    .and_then(|f| tracked.update_write_frontier(shard_id, &f));
                let input_frontier = frontiers
                    .input_frontier
                    .and_then(|f| tracked.update_input_frontier(shard_id, &f));
                let output_frontier = frontiers
                    .output_frontier
                    .and_then(|f| tracked.update_output_frontier(shard_id, &f));

                let frontiers = FrontiersResponse {
                    write_frontier,
                    input_frontier,
                    output_frontier,
                };
                let result = frontiers
                    .has_updates()
                    .then_some(Ok(ComputeResponse::Frontiers(id, frontiers)));

                if tracked.all_empty() {
                    // All shards have reported advancement to the empty frontier, so we do not
                    // expect further updates for this collection.
                    self.frontiers.remove(&id);
                }

                result
            }
            ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
                // Incorporate new peek responses; awaiting all responses.
                let entry = self
                    .peek_responses
                    .entry(uuid)
                    .or_insert_with(Default::default);
                let novel = entry.insert(shard_id, response);
                assert_none!(novel, "Duplicate peek response");
                // We may be ready to respond.
                if entry.len() == self.parts {
                    let mut response = PeekResponse::Rows(RowCollection::default());
                    for (_part, r) in std::mem::take(entry).into_iter() {
                        response = match (response, r) {
                            (_, PeekResponse::Canceled) => PeekResponse::Canceled,
                            (PeekResponse::Canceled, _) => PeekResponse::Canceled,
                            (_, PeekResponse::Error(e)) => PeekResponse::Error(e),
                            (PeekResponse::Error(e), _) => PeekResponse::Error(e),
                            (PeekResponse::Rows(mut rows), PeekResponse::Rows(other)) => {
                                let total_byte_size =
                                    rows.byte_len().saturating_add(other.byte_len());

                                if total_byte_size > usize::cast_from(self.max_result_size) {
                                    // Note: We match on this specific error message in tests
                                    // so it's important that nothing else returns the same
                                    // string.
                                    let err = format!(
                                        "total result exceeds max size of {}",
                                        ByteSize::b(self.max_result_size)
                                    );
                                    PeekResponse::Error(err)
                                } else {
                                    rows.merge(&other);
                                    PeekResponse::Rows(rows)
                                }
                            }
                        };
                    }
                    self.peek_responses.remove(&uuid);
                    // We take the otel_ctx from the last peek, but they should all be the same
                    Some(Ok(ComputeResponse::PeekResponse(uuid, response, otel_ctx)))
                } else {
                    None
                }
            }
            ComputeResponse::SubscribeResponse(id, response) => {
                // Initialize tracking for this subscribe, if necessary.
                let entry = self
                    .pending_subscribes
                    .entry(id)
                    .or_insert_with(|| PendingSubscribe::new(self.parts));

                let emit_response = match response {
                    SubscribeResponse::Batch(batch) => {
                        let frontiers = &mut entry.frontiers;
                        let old_frontier = frontiers.frontier().to_owned();
                        frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
                        frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
                        let new_frontier = frontiers.frontier().to_owned();

                        match (&mut entry.stashed_updates, batch.updates) {
                            (Err(_), _) => {
                                // Subscribe is borked; nothing to do.
                                // TODO: Consider refreshing error?
                            }
                            (_, Err(text)) => {
                                entry.stashed_updates = Err(text);
                            }
                            (Ok(stashed_updates), Ok(updates)) => {
                                stashed_updates.extend(updates);
                            }
                        }

                        // If the frontier has advanced, it is time to announce subscribe progress.
                        // Unless we have already announced that the subscribe has been dropped, in
                        // which case we must keep quiet.
                        if old_frontier != new_frontier && !entry.dropped {
                            let updates = match &mut entry.stashed_updates {
                                Ok(stashed_updates) => {
                                    // The compute protocol requires us to only send out
                                    // consolidated batches.
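                                    // (`consolidate_updates` sums the diffs of identical
                                    // (time, row) pairs and drops updates whose diffs cancel
                                    // out to zero.)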
                                    consolidate_updates(stashed_updates);

                                    let mut ship = Vec::new();
                                    let mut keep = Vec::new();
                                    for (time, data, diff) in stashed_updates.drain(..) {
                                        if new_frontier.less_equal(&time) {
                                            keep.push((time, data, diff));
                                        } else {
                                            ship.push((time, data, diff));
                                        }
                                    }
                                    entry.stashed_updates = Ok(keep);
                                    Ok(ship)
                                }
                                Err(text) => Err(text.clone()),
                            };
                            Some(Ok(ComputeResponse::SubscribeResponse(
                                id,
                                SubscribeResponse::Batch(SubscribeBatch {
                                    lower: old_frontier,
                                    upper: new_frontier,
                                    updates,
                                }),
                            )))
                        } else {
                            None
                        }
                    }
                    SubscribeResponse::DroppedAt(frontier) => {
                        entry
                            .frontiers
                            .update_iter(frontier.iter().map(|t| (t.clone(), -1)));

                        if entry.dropped {
                            None
                        } else {
                            entry.dropped = true;
                            Some(Ok(ComputeResponse::SubscribeResponse(
                                id,
                                SubscribeResponse::DroppedAt(frontier),
                            )))
                        }
                    }
                };

                if entry.frontiers.frontier().is_empty() {
                    // All shards have reported advancement to the empty frontier or dropping, so
                    // we do not expect further updates for this subscribe.
                    self.pending_subscribes.remove(&id);
                }

                emit_response
            }
            ComputeResponse::CopyToResponse(id, response) => {
                // Incorporate new copy to responses; awaiting all responses.
                let entry = self
                    .copy_to_responses
                    .entry(id)
                    .or_insert_with(Default::default);
                let novel = entry.insert(shard_id, response);
                assert_none!(novel, "Duplicate copy to response");
                // We may be ready to respond.
                if entry.len() == self.parts {
                    let mut response = CopyToResponse::RowCount(0);
                    for (_part, r) in std::mem::take(entry).into_iter() {
                        response = match (response, r) {
                            // It's important that we receive all the `Dropped` messages as well
                            // so that the `copy_to_responses` state can be cleared.
                            (_, CopyToResponse::Dropped) => CopyToResponse::Dropped,
                            (CopyToResponse::Dropped, _) => CopyToResponse::Dropped,
                            (_, CopyToResponse::Error(e)) => CopyToResponse::Error(e),
                            (CopyToResponse::Error(e), _) => CopyToResponse::Error(e),
                            (CopyToResponse::RowCount(r1), CopyToResponse::RowCount(r2)) => {
                                CopyToResponse::RowCount(r1 + r2)
                            }
                        };
                    }
                    self.copy_to_responses.remove(&id);
                    Some(Ok(ComputeResponse::CopyToResponse(id, response)))
                } else {
                    None
                }
            }
            response @ ComputeResponse::Status(_) => {
                // Pass through status responses.
                Some(Ok(response))
            }
        }
    }
}

/// Tracked frontiers for an index or a sink collection.
///
/// Each frontier is maintained both as a `MutableAntichain` across all partitions and individually
/// for each partition.
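///
/// # Example
///
/// A small sketch of the underlying bookkeeping (not specific to this type): the
/// `MutableAntichain` counts, per timestamp, how many partitions still sit at that frontier,
/// so its overall frontier is the meet of the per-partition frontiers.
///
/// ```ignore
/// let mut meet = MutableAntichain::new();
/// meet.update_iter([(0u64, 2)]); // two partitions, both at the minimum
/// meet.update_iter([(0u64, -1), (5, 1)]); // partition 0 advances to 5
/// assert_eq!(meet.frontier().to_owned(), Antichain::from_elem(0)); // held back by partition 1
/// meet.update_iter([(0u64, -1), (7, 1)]); // partition 1 advances to 7
/// assert_eq!(meet.frontier().to_owned(), Antichain::from_elem(5));
/// ```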
#[derive(Debug)]
struct TrackedFrontiers<T> {
    /// The tracked write frontier.
    write_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The tracked input frontier.
    input_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The tracked output frontier.
    output_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
}

impl<T> TrackedFrontiers<T>
where
    T: timely::progress::Timestamp + Lattice,
{
    /// Initializes frontier tracking state for a new collection.
    fn new(parts: usize) -> Self {
        // TODO(benesch): fix this dangerous use of `as`.
        #[allow(clippy::as_conversions)]
        let parts_diff = parts as i64;

        let mut frontier = MutableAntichain::new();
        frontier.update_iter([(T::minimum(), parts_diff)]);
        let part_frontiers = vec![Antichain::from_elem(T::minimum()); parts];
        let frontier_entry = (frontier, part_frontiers);

        Self {
            write_frontier: frontier_entry.clone(),
            input_frontier: frontier_entry.clone(),
            output_frontier: frontier_entry,
        }
    }

    /// Returns whether all tracked frontiers have advanced to the empty frontier.
    fn all_empty(&self) -> bool {
        self.write_frontier.0.frontier().is_empty()
            && self.input_frontier.0.frontier().is_empty()
            && self.output_frontier.0.frontier().is_empty()
    }

    /// Updates write frontier tracking with a new shard frontier.
    ///
    /// If this causes the global write frontier to advance, the advanced frontier is returned.
    fn update_write_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
    }

    /// Updates input frontier tracking with a new shard frontier.
    ///
    /// If this causes the global input frontier to advance, the advanced frontier is returned.
    fn update_input_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
    }

    /// Updates output frontier tracking with a new shard frontier.
    ///
    /// If this causes the global output frontier to advance, the advanced frontier is returned.
    fn update_output_frontier(
        &mut self,
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
    }

    /// Updates the provided frontier entry with a new shard frontier.
    fn update_frontier(
        entry: &mut (MutableAntichain<T>, Vec<Antichain<T>>),
        shard_id: usize,
        new_shard_frontier: &Antichain<T>,
    ) -> Option<Antichain<T>> {
        let (frontier, shard_frontiers) = entry;

        let old_frontier = frontier.frontier().to_owned();
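        // Retract the shard's previous frontier from the combined counts, advance the shard's
        // frontier by joining (so it can only move forward), and add the new value back in.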
        let shard_frontier = &mut shard_frontiers[shard_id];
        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
        shard_frontier.join_assign(new_shard_frontier);
        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));

        let new_frontier = frontier.frontier();

        if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
            Some(new_frontier.to_owned())
        } else {
            None
        }
    }
}

#[derive(Debug)]
struct PendingSubscribe<T> {
    /// The subscribe frontiers of the partitioned shards.
    frontiers: MutableAntichain<T>,
    /// The updates we are holding back until their timestamps are complete.
    stashed_updates: Result<Vec<(T, Row, Diff)>, String>,
    /// Whether we have already emitted a `DroppedAt` response for this subscribe.
    ///
    /// This field is used to ensure we emit such a response only once.
    dropped: bool,
}

impl<T: ComputeControllerTimestamp> PendingSubscribe<T> {
    fn new(parts: usize) -> Self {
        let mut frontiers = MutableAntichain::new();
        // TODO(benesch): fix this dangerous use of `as`.
        #[allow(clippy::as_conversions)]
        frontiers.update_iter([(T::minimum(), parts as i64)]);

        Self {
            frontiers,
            stashed_updates: Ok(Vec::new()),
            dropped: false,
        }
    }
}