mz_compute_client/
service.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10// Tonic generates code that violates clippy lints.
11// TODO: Remove this once tonic does not produce this code anymore.
12#![allow(clippy::as_conversions, clippy::clone_on_ref_ptr)]
13
14//! Compute layer client and server.
15
16use std::collections::{BTreeMap, BTreeSet};
17use std::mem;
18
19use async_trait::async_trait;
20use bytesize::ByteSize;
21use differential_dataflow::consolidation::consolidate_updates;
22use differential_dataflow::lattice::Lattice;
23use mz_expr::row::RowCollection;
24use mz_ore::cast::CastInto;
25use mz_ore::soft_panic_or_log;
26use mz_ore::tracing::OpenTelemetryContext;
27use mz_repr::{Diff, GlobalId, Row};
28use mz_service::client::{GenericClient, Partitionable, PartitionedState};
29use mz_service::grpc::{GrpcClient, GrpcServer, ProtoServiceTypes, ResponseStream};
30use timely::PartialOrder;
31use timely::progress::frontier::{Antichain, MutableAntichain};
32use tonic::{Request, Status, Streaming};
33use uuid::Uuid;
34
35use crate::controller::ComputeControllerTimestamp;
36use crate::metrics::ReplicaMetrics;
37use crate::protocol::command::{ComputeCommand, ProtoComputeCommand};
38use crate::protocol::response::{
39    ComputeResponse, CopyToResponse, FrontiersResponse, PeekResponse, ProtoComputeResponse,
40    StashedPeekResponse, SubscribeBatch, SubscribeResponse,
41};
42use crate::service::proto_compute_server::ProtoCompute;
43
44include!(concat!(env!("OUT_DIR"), "/mz_compute_client.service.rs"));
45
46/// A client to a compute server.
47pub trait ComputeClient<T = mz_repr::Timestamp>:
48    GenericClient<ComputeCommand<T>, ComputeResponse<T>>
49{
50}
51
52impl<C, T> ComputeClient<T> for C where C: GenericClient<ComputeCommand<T>, ComputeResponse<T>> {}
53
54#[async_trait]
55impl<T: Send> GenericClient<ComputeCommand<T>, ComputeResponse<T>> for Box<dyn ComputeClient<T>> {
56    async fn send(&mut self, cmd: ComputeCommand<T>) -> Result<(), anyhow::Error> {
57        (**self).send(cmd).await
58    }
59
60    /// # Cancel safety
61    ///
62    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
63    /// statement and some other branch completes first, it is guaranteed that no messages were
64    /// received by this client.
65    async fn recv(&mut self) -> Result<Option<ComputeResponse<T>>, anyhow::Error> {
66        // `GenericClient::recv` is required to be cancel safe.
67        (**self).recv().await
68    }
69}
70
71/// TODO(database-issues#7533): Add documentation.
72#[derive(Debug, Clone)]
73pub enum ComputeProtoServiceTypes {}
74
75impl ProtoServiceTypes for ComputeProtoServiceTypes {
76    type PC = ProtoComputeCommand;
77    type PR = ProtoComputeResponse;
78    type STATS = ReplicaMetrics;
79    const URL: &'static str = "/mz_compute_client.service.ProtoCompute/CommandResponseStream";
80}
81
82/// TODO(database-issues#7533): Add documentation.
83pub type ComputeGrpcClient = GrpcClient<ComputeProtoServiceTypes>;
84
85#[async_trait]
86impl<F, G> ProtoCompute for GrpcServer<F>
87where
88    F: Fn() -> G + Send + Sync + 'static,
89    G: ComputeClient + 'static,
90{
91    type CommandResponseStreamStream = ResponseStream<ProtoComputeResponse>;
92
93    async fn command_response_stream(
94        &self,
95        request: Request<Streaming<ProtoComputeCommand>>,
96    ) -> Result<tonic::Response<Self::CommandResponseStreamStream>, Status> {
97        self.forward_bidi_stream(request).await
98    }
99}
100
/// Maintained state for partitioned compute clients.
///
/// This helper type unifies the responses of multiple partitioned workers in order to present as a
/// single worker:
///
///   * It emits `Frontiers` responses reporting the minimum/meet of frontiers reported by the
///     individual workers.
///   * It emits `PeekResponse`s and `SubscribeResponse`s reporting the union of the responses
///     received from the workers.
///   * It emits `CopyToResponse`s combining the responses received from the workers.
///
/// In the compute communication stack, this client is instantiated several times:
///
///   * One instance on the controller side, dispatching between cluster processes.
///   * One instance in each cluster process, dispatching between timely worker threads.
///
/// Note that because compute commands, except `Hello` and `UpdateConfiguration`, are only
/// sent to the first process, the cluster-side instances of `PartitionedComputeState` are not
/// guaranteed to see all compute commands. More specifically: the instance running inside
/// process 0 sees all commands, whereas the instances running inside the other processes only see
/// `Hello` and `UpdateConfiguration`. The `PartitionedComputeState` implementation must be
/// able to cope with this limited visibility. It does so by performing most of its state management
/// based on observed compute responses rather than commands.
#[derive(Debug)]
pub struct PartitionedComputeState<T> {
    /// Number of partitions the state machine represents.
    parts: usize,
    /// The maximum result size this state machine can return.
    ///
    /// This starts out as `u64::MAX` (effectively no limit) and is updated upon receiving
    /// [`ComputeCommand::UpdateConfiguration`]s that specify a `max_result_size`.
    max_result_size: u64,
    /// Tracked frontiers for indexes and sinks.
    ///
    /// Frontier tracking for a collection is initialized when the first `Frontiers` response
    /// for that collection is received. Frontier tracking stops once all shards have reported
    /// advancement to the empty frontier for all frontier kinds.
    ///
    /// The compute protocol requires that shards always emit `Frontiers` responses reporting empty
    /// frontiers for all frontier kinds when a collection is dropped. It further requires that no
    /// further `Frontiers` responses are emitted for a collection after the empty frontiers were
    /// reported. These properties ensure that a) we always stop frontier tracking for collections
    /// that have been dropped and b) frontier tracking for a collection is not re-initialized
    /// after it was stopped.
    frontiers: BTreeMap<GlobalId, TrackedFrontiers<T>>,
    /// For each in-progress peek the response data received so far, and the set of shards that
    /// provided responses already.
    ///
    /// Tracking of responses for a peek is initialized when the first `PeekResponse` for that peek
    /// is received. Once all shards have provided a `PeekResponse`, a unified peek response is
    /// emitted and the peek tracking state is dropped again.
    ///
    /// The compute protocol requires that exactly one response is emitted for each peek. This
    /// property ensures that a) we can eventually drop the tracking state maintained for a peek
    /// and b) we won't re-initialize tracking for a peek we have already served.
    peek_responses: BTreeMap<Uuid, (PeekResponse, BTreeSet<usize>)>,
    /// For each in-progress COPY TO the response data received so far, and the set of shards
    /// that provided responses already.
    ///
    /// Tracking of responses for a COPY TO is initialized when the first `CopyToResponse` for
    /// that command is received. Once all shards have provided a `CopyToResponse`, a unified
    /// response is emitted and the tracking state for the COPY TO is dropped again.
    ///
    /// The compute protocol requires that exactly one response is emitted for each COPY TO
    /// command. This property ensures that a) we can eventually drop the tracking state
    /// maintained for a COPY TO and b) we won't re-initialize tracking for a COPY TO we have
    /// already served.
    copy_to_responses: BTreeMap<GlobalId, (CopyToResponse, BTreeSet<usize>)>,
    /// Tracks in-progress `SUBSCRIBE`s, and the stashed rows we are holding back until their
    /// timestamps are complete.
    ///
    /// The updates may be `Err` if any of the batches have reported an error, in which case the
    /// subscribe is permanently broken.
    ///
    /// Tracking of a subscribe is initialized when the first `SubscribeResponse` for that
    /// subscribe is received. Once all shards have emitted an "end-of-subscribe" response the
    /// subscribe tracking state is dropped again.
    ///
    /// The compute protocol requires that for a subscribe that shuts down an end-of-subscribe
    /// response is emitted:
    ///
    ///   * Either a `Batch` response reporting advancement to the empty frontier...
    ///   * ... or a `DroppedAt` response reporting that the subscribe was dropped before
    ///     completing.
    ///
    /// The compute protocol further requires that no further `SubscribeResponse`s are emitted for
    /// a subscribe after an end-of-subscribe was reported.
    ///
    /// These two properties ensure that a) once a subscribe has shut down, we can eventually drop
    /// the tracking state maintained for it and b) we won't re-initialize tracking for a subscribe
    /// we have already dropped.
    pending_subscribes: BTreeMap<GlobalId, PendingSubscribe<T>>,
}
191
192impl<T> Partitionable<ComputeCommand<T>, ComputeResponse<T>>
193    for (ComputeCommand<T>, ComputeResponse<T>)
194where
195    T: ComputeControllerTimestamp,
196{
197    type PartitionedState = PartitionedComputeState<T>;
198
199    fn new(parts: usize) -> PartitionedComputeState<T> {
200        PartitionedComputeState {
201            parts,
202            max_result_size: u64::MAX,
203            frontiers: BTreeMap::new(),
204            peek_responses: BTreeMap::new(),
205            pending_subscribes: BTreeMap::new(),
206            copy_to_responses: BTreeMap::new(),
207        }
208    }
209}
210
211impl<T> PartitionedComputeState<T>
212where
213    T: ComputeControllerTimestamp,
214{
215    /// Observes commands that move past.
216    pub fn observe_command(&mut self, command: &ComputeCommand<T>) {
217        match command {
218            ComputeCommand::UpdateConfiguration(config) => {
219                if let Some(max_result_size) = config.max_result_size {
220                    self.max_result_size = max_result_size;
221                }
222            }
223            _ => {
224                // We are not guaranteed to observe other compute commands. We
225                // must therefore not add any logic here that relies on doing so.
226            }
227        }
228    }
229
230    /// Absorb a [`ComputeResponse::Frontiers`].
231    fn absorb_frontiers(
232        &mut self,
233        shard_id: usize,
234        collection_id: GlobalId,
235        frontiers: FrontiersResponse<T>,
236    ) -> Option<ComputeResponse<T>> {
237        let tracked = self
238            .frontiers
239            .entry(collection_id)
240            .or_insert_with(|| TrackedFrontiers::new(self.parts));
241
242        let write_frontier = frontiers
243            .write_frontier
244            .and_then(|f| tracked.update_write_frontier(shard_id, &f));
245        let input_frontier = frontiers
246            .input_frontier
247            .and_then(|f| tracked.update_input_frontier(shard_id, &f));
248        let output_frontier = frontiers
249            .output_frontier
250            .and_then(|f| tracked.update_output_frontier(shard_id, &f));
251
252        let frontiers = FrontiersResponse {
253            write_frontier,
254            input_frontier,
255            output_frontier,
256        };
257        let result = frontiers
258            .has_updates()
259            .then_some(ComputeResponse::Frontiers(collection_id, frontiers));
260
261        if tracked.all_empty() {
262            // All shards have reported advancement to the empty frontier, so we do not
263            // expect further updates for this collection.
264            self.frontiers.remove(&collection_id);
265        }
266
267        result
268    }
269
270    /// Absorb a [`ComputeResponse::PeekResponse`].
271    fn absorb_peek_response(
272        &mut self,
273        shard_id: usize,
274        uuid: Uuid,
275        response: PeekResponse,
276        otel_ctx: OpenTelemetryContext,
277    ) -> Option<ComputeResponse<T>> {
278        let (merged, ready_shards) = self.peek_responses.entry(uuid).or_insert((
279            PeekResponse::Rows(RowCollection::default()),
280            BTreeSet::new(),
281        ));
282
283        let first = ready_shards.insert(shard_id);
284        assert!(first, "duplicate peek response");
285
286        let resp1 = mem::replace(merged, PeekResponse::Canceled);
287        *merged = merge_peek_responses(resp1, response, self.max_result_size);
288
289        if ready_shards.len() == self.parts {
290            let (response, _) = self.peek_responses.remove(&uuid).unwrap();
291            Some(ComputeResponse::PeekResponse(uuid, response, otel_ctx))
292        } else {
293            None
294        }
295    }
296
297    /// Absorb a [`ComputeResponse::CopyToResponse`].
298    fn absorb_copy_to_response(
299        &mut self,
300        shard_id: usize,
301        copyto_id: GlobalId,
302        response: CopyToResponse,
303    ) -> Option<ComputeResponse<T>> {
304        use CopyToResponse::*;
305
306        let (merged, ready_shards) = self
307            .copy_to_responses
308            .entry(copyto_id)
309            .or_insert((CopyToResponse::RowCount(0), BTreeSet::new()));
310
311        let first = ready_shards.insert(shard_id);
312        assert!(first, "duplicate copy-to response");
313
314        let resp1 = mem::replace(merged, Dropped);
315        *merged = match (resp1, response) {
316            (Dropped, _) | (_, Dropped) => Dropped,
317            (Error(e), _) | (_, Error(e)) => Error(e),
318            (RowCount(r1), RowCount(r2)) => RowCount(r1 + r2),
319        };
320
321        if ready_shards.len() == self.parts {
322            let (response, _) = self.copy_to_responses.remove(&copyto_id).unwrap();
323            Some(ComputeResponse::CopyToResponse(copyto_id, response))
324        } else {
325            None
326        }
327    }
328
329    /// Absorb a [`ComputeResponse::SubscribeResponse`].
330    fn absorb_subscribe_response(
331        &mut self,
332        subscribe_id: GlobalId,
333        response: SubscribeResponse<T>,
334    ) -> Option<ComputeResponse<T>> {
335        let tracked = self
336            .pending_subscribes
337            .entry(subscribe_id)
338            .or_insert_with(|| PendingSubscribe::new(self.parts));
339
340        let emit_response = match response {
341            SubscribeResponse::Batch(batch) => {
342                let frontiers = &mut tracked.frontiers;
343                let old_frontier = frontiers.frontier().to_owned();
344                frontiers.update_iter(batch.lower.into_iter().map(|t| (t, -1)));
345                frontiers.update_iter(batch.upper.into_iter().map(|t| (t, 1)));
346                let new_frontier = frontiers.frontier().to_owned();
347
348                tracked.stash(batch.updates, self.max_result_size);
349
350                // If the frontier has advanced, it is time to announce subscribe progress. Unless
351                // we have already announced that the subscribe has been dropped, in which case we
352                // must keep quiet.
353                if old_frontier != new_frontier && !tracked.dropped {
354                    let updates = match &mut tracked.stashed_updates {
355                        Ok(stashed_updates) => {
356                            // The compute protocol requires us to only send out consolidated
357                            // batches.
358                            consolidate_updates(stashed_updates);
359
360                            let mut ship = Vec::new();
361                            let mut keep = Vec::new();
362                            for (time, data, diff) in stashed_updates.drain(..) {
363                                if new_frontier.less_equal(&time) {
364                                    keep.push((time, data, diff));
365                                } else {
366                                    ship.push((time, data, diff));
367                                }
368                            }
369                            tracked.stashed_updates = Ok(keep);
370                            Ok(ship)
371                        }
372                        Err(text) => Err(text.clone()),
373                    };
374                    Some(ComputeResponse::SubscribeResponse(
375                        subscribe_id,
376                        SubscribeResponse::Batch(SubscribeBatch {
377                            lower: old_frontier,
378                            upper: new_frontier,
379                            updates,
380                        }),
381                    ))
382                } else {
383                    None
384                }
385            }
386            SubscribeResponse::DroppedAt(frontier) => {
387                tracked
388                    .frontiers
389                    .update_iter(frontier.iter().map(|t| (t.clone(), -1)));
390
391                if tracked.dropped {
392                    None
393                } else {
394                    tracked.dropped = true;
395                    Some(ComputeResponse::SubscribeResponse(
396                        subscribe_id,
397                        SubscribeResponse::DroppedAt(frontier),
398                    ))
399                }
400            }
401        };
402
403        if tracked.frontiers.frontier().is_empty() {
404            // All shards have reported advancement to the empty frontier or dropping, so
405            // we do not expect further updates for this subscribe.
406            self.pending_subscribes.remove(&subscribe_id);
407        }
408
409        emit_response
410    }
411}
412
413impl<T> PartitionedState<ComputeCommand<T>, ComputeResponse<T>> for PartitionedComputeState<T>
414where
415    T: ComputeControllerTimestamp,
416{
417    fn split_command(&mut self, command: ComputeCommand<T>) -> Vec<Option<ComputeCommand<T>>> {
418        self.observe_command(&command);
419
420        // As specified by the compute protocol:
421        //  * Forward `Hello` and `UpdateConfiguration` commands to all shards.
422        //  * Forward all other commands to the first shard only.
423        match command {
424            command @ ComputeCommand::Hello { .. }
425            | command @ ComputeCommand::UpdateConfiguration(_) => {
426                vec![Some(command); self.parts]
427            }
428            command => {
429                let mut r = vec![None; self.parts];
430                r[0] = Some(command);
431                r
432            }
433        }
434    }
435
436    fn absorb_response(
437        &mut self,
438        shard_id: usize,
439        message: ComputeResponse<T>,
440    ) -> Option<Result<ComputeResponse<T>, anyhow::Error>> {
441        let response = match message {
442            ComputeResponse::Frontiers(id, frontiers) => {
443                self.absorb_frontiers(shard_id, id, frontiers)
444            }
445            ComputeResponse::PeekResponse(uuid, response, otel_ctx) => {
446                self.absorb_peek_response(shard_id, uuid, response, otel_ctx)
447            }
448            ComputeResponse::SubscribeResponse(id, response) => {
449                self.absorb_subscribe_response(id, response)
450            }
451            ComputeResponse::CopyToResponse(id, response) => {
452                self.absorb_copy_to_response(shard_id, id, response)
453            }
454            response @ ComputeResponse::Status(_) => {
455                // Pass through status responses.
456                Some(response)
457            }
458        };
459
460        response.map(Ok)
461    }
462}
463
/// Tracked frontiers for an index or a sink collection.
///
/// Each frontier is maintained both as a `MutableAntichain` across all partitions and individually
/// for each partition. In each `(MutableAntichain<T>, Vec<Antichain<T>>)` pair:
///
///   * the first element accumulates the frontiers of all partitions, so its `frontier()` is
///     their meet, and
///   * the second element holds, indexed by shard ID, the frontier reported by each partition
///     (joined with that partition's earlier reports).
#[derive(Debug)]
struct TrackedFrontiers<T> {
    /// The tracked write frontier.
    write_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The tracked input frontier.
    input_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
    /// The tracked output frontier.
    output_frontier: (MutableAntichain<T>, Vec<Antichain<T>>),
}
477
478impl<T> TrackedFrontiers<T>
479where
480    T: timely::progress::Timestamp + Lattice,
481{
482    /// Initializes frontier tracking state for a new collection.
483    fn new(parts: usize) -> Self {
484        // TODO(benesch): fix this dangerous use of `as`.
485        #[allow(clippy::as_conversions)]
486        let parts_diff = parts as i64;
487
488        let mut frontier = MutableAntichain::new();
489        frontier.update_iter([(T::minimum(), parts_diff)]);
490        let part_frontiers = vec![Antichain::from_elem(T::minimum()); parts];
491        let frontier_entry = (frontier, part_frontiers);
492
493        Self {
494            write_frontier: frontier_entry.clone(),
495            input_frontier: frontier_entry.clone(),
496            output_frontier: frontier_entry,
497        }
498    }
499
500    /// Returns whether all tracked frontiers have advanced to the empty frontier.
501    fn all_empty(&self) -> bool {
502        self.write_frontier.0.frontier().is_empty()
503            && self.input_frontier.0.frontier().is_empty()
504            && self.output_frontier.0.frontier().is_empty()
505    }
506
507    /// Updates write frontier tracking with a new shard frontier.
508    ///
509    /// If this causes the global write frontier to advance, the advanced frontier is returned.
510    fn update_write_frontier(
511        &mut self,
512        shard_id: usize,
513        new_shard_frontier: &Antichain<T>,
514    ) -> Option<Antichain<T>> {
515        Self::update_frontier(&mut self.write_frontier, shard_id, new_shard_frontier)
516    }
517
518    /// Updates input frontier tracking with a new shard frontier.
519    ///
520    /// If this causes the global input frontier to advance, the advanced frontier is returned.
521    fn update_input_frontier(
522        &mut self,
523        shard_id: usize,
524        new_shard_frontier: &Antichain<T>,
525    ) -> Option<Antichain<T>> {
526        Self::update_frontier(&mut self.input_frontier, shard_id, new_shard_frontier)
527    }
528
529    /// Updates output frontier tracking with a new shard frontier.
530    ///
531    /// If this causes the global output frontier to advance, the advanced frontier is returned.
532    fn update_output_frontier(
533        &mut self,
534        shard_id: usize,
535        new_shard_frontier: &Antichain<T>,
536    ) -> Option<Antichain<T>> {
537        Self::update_frontier(&mut self.output_frontier, shard_id, new_shard_frontier)
538    }
539
540    /// Updates the provided frontier entry with a new shard frontier.
541    fn update_frontier(
542        entry: &mut (MutableAntichain<T>, Vec<Antichain<T>>),
543        shard_id: usize,
544        new_shard_frontier: &Antichain<T>,
545    ) -> Option<Antichain<T>> {
546        let (frontier, shard_frontiers) = entry;
547
548        let old_frontier = frontier.frontier().to_owned();
549        let shard_frontier = &mut shard_frontiers[shard_id];
550        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), -1)));
551        shard_frontier.join_assign(new_shard_frontier);
552        frontier.update_iter(shard_frontier.iter().map(|t| (t.clone(), 1)));
553
554        let new_frontier = frontier.frontier();
555
556        if PartialOrder::less_than(&old_frontier.borrow(), &new_frontier) {
557            Some(new_frontier.to_owned())
558        } else {
559            None
560        }
561    }
562}
563
/// Tracking state for a subscribe whose responses are being merged across partitions.
#[derive(Debug)]
struct PendingSubscribe<T> {
    /// The subscribe frontiers of the partitioned shards, accumulated so that `frontier()`
    /// yields their meet.
    frontiers: MutableAntichain<T>,
    /// The updates we are holding back until their timestamps are complete.
    ///
    /// `Err` once any shard has reported an error or the stash exceeded the result size limit;
    /// the subscribe stays broken from then on.
    stashed_updates: Result<Vec<(T, Row, Diff)>, String>,
    /// The row size (sum of `Row::byte_len`) of stashed updates, for `max_result_size` checking.
    stashed_result_size: usize,
    /// Whether we have already emitted a `DroppedAt` response for this subscribe.
    ///
    /// This field is used to ensure we emit such a response only once.
    dropped: bool,
}
577
578impl<T: ComputeControllerTimestamp> PendingSubscribe<T> {
579    fn new(parts: usize) -> Self {
580        let mut frontiers = MutableAntichain::new();
581        // TODO(benesch): fix this dangerous use of `as`.
582        #[allow(clippy::as_conversions)]
583        frontiers.update_iter([(T::minimum(), parts as i64)]);
584
585        Self {
586            frontiers,
587            stashed_updates: Ok(Vec::new()),
588            stashed_result_size: 0,
589            dropped: false,
590        }
591    }
592
593    /// Stash a new batch of updates.
594    ///
595    /// This also implements the short-circuit behavior of error responses, and performs
596    /// `max_result_size` checking.
597    fn stash(&mut self, new_updates: Result<Vec<(T, Row, Diff)>, String>, max_result_size: u64) {
598        match (&mut self.stashed_updates, new_updates) {
599            (Err(_), _) => {
600                // Subscribe is borked; nothing to do.
601                // TODO: Consider refreshing error?
602            }
603            (_, Err(text)) => {
604                self.stashed_updates = Err(text);
605            }
606            (Ok(stashed), Ok(new)) => {
607                let new_size: usize = new.iter().map(|(_, row, _)| row.byte_len()).sum();
608                self.stashed_result_size += new_size;
609
610                if self.stashed_result_size > max_result_size.cast_into() {
611                    self.stashed_updates = Err(format!(
612                        "total result exceeds max size of {}",
613                        ByteSize::b(max_result_size)
614                    ));
615                } else {
616                    stashed.extend(new);
617                }
618            }
619        }
620    }
621}
622
623/// Merge two [`PeekResponse`]s.
624fn merge_peek_responses(
625    resp1: PeekResponse,
626    resp2: PeekResponse,
627    max_result_size: u64,
628) -> PeekResponse {
629    use PeekResponse::*;
630
631    // Cancelations and errors short-circuit. Cancelations take precedence over errors.
632    let (resp1, resp2) = match (resp1, resp2) {
633        (Canceled, _) | (_, Canceled) => return Canceled,
634        (Error(e), _) | (_, Error(e)) => return Error(e),
635        resps => resps,
636    };
637
638    let total_byte_len = resp1.inline_byte_len() + resp2.inline_byte_len();
639    if total_byte_len > max_result_size.cast_into() {
640        // Note: We match on this specific error message in tests so it's important that
641        // nothing else returns the same string.
642        let err = format!(
643            "total result exceeds max size of {}",
644            ByteSize::b(max_result_size)
645        );
646        return Error(err);
647    }
648
649    match (resp1, resp2) {
650        (Rows(mut rows1), Rows(rows2)) => {
651            rows1.merge(&rows2);
652            Rows(rows1)
653        }
654        (Rows(rows), Stashed(mut stashed)) | (Stashed(mut stashed), Rows(rows)) => {
655            stashed.inline_rows.merge(&rows);
656            Stashed(stashed)
657        }
658        (Stashed(stashed1), Stashed(stashed2)) => {
659            // Deconstruct so we don't miss adding new fields. We need to be careful about
660            // merging everything!
661            let StashedPeekResponse {
662                num_rows_batches: num_rows_batches1,
663                encoded_size_bytes: encoded_size_bytes1,
664                relation_desc: relation_desc1,
665                shard_id: shard_id1,
666                batches: mut batches1,
667                inline_rows: mut inline_rows1,
668            } = *stashed1;
669            let StashedPeekResponse {
670                num_rows_batches: num_rows_batches2,
671                encoded_size_bytes: encoded_size_bytes2,
672                relation_desc: relation_desc2,
673                shard_id: shard_id2,
674                batches: mut batches2,
675                inline_rows: inline_rows2,
676            } = *stashed2;
677
678            if shard_id1 != shard_id2 {
679                soft_panic_or_log!(
680                    "shard IDs of stashed responses do not match: \
681                             {shard_id1} != {shard_id2}"
682                );
683                return Error("internal error".into());
684            }
685            if relation_desc1 != relation_desc2 {
686                soft_panic_or_log!(
687                    "relation descs of stashed responses do not match: \
688                             {relation_desc1:?} != {relation_desc2:?}"
689                );
690                return Error("internal error".into());
691            }
692
693            batches1.append(&mut batches2);
694            inline_rows1.merge(&inline_rows2);
695
696            Stashed(Box::new(StashedPeekResponse {
697                num_rows_batches: num_rows_batches1 + num_rows_batches2,
698                encoded_size_bytes: encoded_size_bytes1 + encoded_size_bytes2,
699                relation_desc: relation_desc1,
700                shard_id: shard_id1,
701                batches: batches1,
702                inline_rows: inline_rows1,
703            }))
704        }
705        _ => unreachable!("handled above"),
706    }
707}