Skip to main content

mz_clusterd_test_driver/
responses.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Owns the receive side of the CTP connection: a background task that pumps
11//! `ComputeResponse`s into per-id frontier watches, per-uuid peek channels, and
12//! a raw broadcast. The mechanism does not curate which responses or which
13//! frontier fields a use case may observe: frontier watches keep the full
14//! merged `FrontiersResponse`, and the raw broadcast carries every response.
15
16use std::collections::BTreeMap;
17use std::sync::{Arc, Mutex};
18
19use mz_compute_client::protocol::command::ComputeCommand;
20use mz_compute_client::protocol::response::{
21    ComputeResponse, FrontiersResponse, PeekResponse, SubscribeResponse,
22};
23use mz_repr::{GlobalId, Row, Timestamp};
24use mz_service::client::GenericClient;
25use timely::progress::Antichain;
26use tokio::sync::{broadcast, oneshot, watch};
27
28use crate::ctp::ComputeCtpClient;
29
30type FrontierTx = watch::Sender<FrontiersResponse>;
31type FrontierRx = watch::Receiver<FrontiersResponse>;
32
33/// Buffered state for one subscribe sink: its accumulated updates, a watch on its
34/// upper frontier (so a waiter can block until it reaches a target), and the first
35/// error the replica reported, if any.
36///
37/// Subscribe batches arrive asynchronously and out of band of the command that
38/// created the sink, so the pump accumulates them here as they land; the
39/// `await-subscribe` command drains them once the upper reaches its target.
40struct SubscribeState {
41    updates: Vec<(Row, Timestamp, i64)>,
42    upper_tx: watch::Sender<Antichain<Timestamp>>,
43    error: Option<String>,
44}
45
46impl SubscribeState {
47    fn new() -> Self {
48        SubscribeState {
49            updates: Vec::new(),
50            upper_tx: watch::channel(Antichain::from_elem(Timestamp::default())).0,
51            error: None,
52        }
53    }
54}
55
56struct Shared {
57    frontiers: BTreeMap<GlobalId, FrontierTx>,
58    peeks: BTreeMap<uuid::Uuid, oneshot::Sender<PeekResponse>>,
59    subscribes: BTreeMap<GlobalId, SubscribeState>,
60    raw: broadcast::Sender<ComputeResponse>,
61}
62
63/// Handle to the response side. Cloneable view onto frontier watches, peek
64/// routing, and the raw response broadcast.
65#[derive(Clone)]
66pub struct Responses {
67    shared: Arc<Mutex<Shared>>,
68}
69
70impl Responses {
71    /// Spawns the pump task that owns the client's receive half.
72    pub fn spawn(mut client: ComputeCtpClient) -> (Self, ComputeSender) {
73        let (raw_tx, _) = broadcast::channel(1024);
74        let shared = Arc::new(Mutex::new(Shared {
75            frontiers: BTreeMap::new(),
76            peeks: BTreeMap::new(),
77            subscribes: BTreeMap::new(),
78            raw: raw_tx,
79        }));
80        let pump_shared = Arc::clone(&shared);
81        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<ComputeCommand>();
82        mz_ore::task::spawn(|| "compute_response_pump", async move {
83            loop {
84                tokio::select! {
85                    cmd = cmd_rx.recv() => match cmd {
86                        // Log a send failure: callers waiting on frontiers/peeks
87                        // would otherwise see only a misleading timeout.
88                        Some(cmd) => {
89                            if let Err(e) = client.send(cmd).await {
90                                tracing::error!("compute command send failed: {e}");
91                                break;
92                            }
93                        }
94                        None => break,
95                    },
96                    resp = client.recv() => match resp {
97                        Ok(Some(resp)) => Self::dispatch(&pump_shared, resp),
98                        // Distinguish a clean EOF from a transport error so that
99                        // an e2e hang has a breadcrumb rather than silent death.
100                        Ok(None) => {
101                            tracing::warn!("clusterd closed the compute connection");
102                            break;
103                        }
104                        Err(e) => {
105                            tracing::error!("compute response recv failed: {e}");
106                            break;
107                        }
108                    },
109                }
110            }
111        });
112        (Responses { shared }, ComputeSender { tx: cmd_tx })
113    }
114
115    fn dispatch(shared: &Arc<Mutex<Shared>>, resp: ComputeResponse) {
116        let mut g = shared.lock().expect("lock");
117        let _ = g.raw.send(resp.clone());
118        match resp {
119            ComputeResponse::Frontiers(id, f) => {
120                let tx = g
121                    .frontiers
122                    .entry(id)
123                    .or_insert_with(|| watch::channel(FrontiersResponse::default()).0);
124                let mut cur = tx.borrow().clone();
125                if f.write_frontier.is_some() {
126                    cur.write_frontier = f.write_frontier;
127                }
128                if f.input_frontier.is_some() {
129                    cur.input_frontier = f.input_frontier;
130                }
131                if f.output_frontier.is_some() {
132                    cur.output_frontier = f.output_frontier;
133                }
134                let _ = tx.send(cur);
135            }
136            ComputeResponse::PeekResponse(uuid, pr, _otel) => {
137                if let Some(tx) = g.peeks.remove(&uuid) {
138                    let _ = tx.send(pr);
139                }
140            }
141            ComputeResponse::SubscribeResponse(id, sr) => {
142                let state = g.subscribes.entry(id).or_insert_with(SubscribeState::new);
143                let upper = match sr {
144                    SubscribeResponse::Batch(batch) => {
145                        match batch.updates {
146                            Ok(collections) => {
147                                for collection in collections {
148                                    for (row, ts, diff) in collection.iter() {
149                                        state.updates.push((
150                                            row.to_owned(),
151                                            *ts,
152                                            diff.into_inner(),
153                                        ));
154                                    }
155                                }
156                            }
157                            // Record the first error; the size-limit / internal error
158                            // path replaces the updates with a message.
159                            Err(e) => {
160                                if state.error.is_none() {
161                                    state.error = Some(e);
162                                }
163                            }
164                        }
165                        batch.upper
166                    }
167                    // A drop leaves later updates unspecified; treat the drop frontier
168                    // as the final upper so a waiter unblocks.
169                    SubscribeResponse::DroppedAt(frontier) => frontier,
170                };
171                let _ = state.upper_tx.send(upper);
172            }
173            _ => {}
174        }
175    }
176
177    /// Returns a watch receiver for an id's full (merged) frontiers, created
178    /// lazily. Use cases read whichever of write/input/output they need.
179    pub fn frontier(&self, id: GlobalId) -> FrontierRx {
180        let mut g = self.shared.lock().expect("lock");
181        g.frontiers
182            .entry(id)
183            .or_insert_with(|| watch::channel(FrontiersResponse::default()).0)
184            .subscribe()
185    }
186
187    /// Subscribes to every `ComputeResponse` the replica sends.
188    pub fn subscribe_raw(&self) -> broadcast::Receiver<ComputeResponse> {
189        self.shared.lock().expect("lock").raw.subscribe()
190    }
191
192    /// Registers interest in a peek's response before the Peek command is sent.
193    pub fn register_peek(&self, uuid: uuid::Uuid) -> oneshot::Receiver<PeekResponse> {
194        let (tx, rx) = oneshot::channel();
195        self.shared.lock().expect("lock").peeks.insert(uuid, tx);
196        rx
197    }
198
199    /// Ensures a subscribe buffer exists for `id` and returns a watch receiver for
200    /// its upper frontier, created lazily. Call this before scheduling the sink so
201    /// the upper watch is observable; the pump accumulates batches regardless.
202    pub fn ensure_subscribe(&self, id: GlobalId) -> watch::Receiver<Antichain<Timestamp>> {
203        let mut g = self.shared.lock().expect("lock");
204        g.subscribes
205            .entry(id)
206            .or_insert_with(SubscribeState::new)
207            .upper_tx
208            .subscribe()
209    }
210
211    /// Drains the buffered updates for subscribe `id`, returning them as
212    /// `(row, time, diff)` triples. Errors if the replica reported a subscribe
213    /// error (e.g. a result-size overflow), so the assertion fails loudly rather
214    /// than on a silently truncated batch.
215    pub fn drain_subscribe(&self, id: GlobalId) -> anyhow::Result<Vec<(Row, Timestamp, i64)>> {
216        let mut g = self.shared.lock().expect("lock");
217        let state = g
218            .subscribes
219            .get_mut(&id)
220            .ok_or_else(|| anyhow::anyhow!("no subscribe registered for {id}"))?;
221        if let Some(e) = &state.error {
222            anyhow::bail!("subscribe {id} reported an error: {e}");
223        }
224        Ok(std::mem::take(&mut state.updates))
225    }
226}
227
228/// Send half: forwards commands into the pump task that owns the client.
229#[derive(Clone)]
230pub struct ComputeSender {
231    tx: tokio::sync::mpsc::UnboundedSender<ComputeCommand>,
232}
233
234impl ComputeSender {
235    pub fn send(&self, cmd: ComputeCommand) -> anyhow::Result<()> {
236        self.tx
237            .send(cmd)
238            .map_err(|_| anyhow::anyhow!("pump task gone"))?;
239        Ok(())
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use mz_repr::Timestamp;
    use timely::progress::Antichain;

    /// Builds shared state with no watches, peeks, or subscribes and a small
    /// raw broadcast channel.
    fn empty_shared() -> Arc<Mutex<Shared>> {
        let (raw, _) = broadcast::channel(16);
        let state = Shared {
            frontiers: BTreeMap::new(),
            peeks: BTreeMap::new(),
            subscribes: BTreeMap::new(),
            raw,
        };
        Arc::new(Mutex::new(state))
    }

    /// Two partial `Frontiers` responses for the same id must merge: the second
    /// sets the input frontier without clearing the output frontier set by the
    /// first. The raw broadcast must also have received a response.
    #[mz_ore::test]
    fn dispatch_merges_frontier_and_broadcasts() {
        let shared = empty_shared();
        let id = GlobalId::User(1);
        let output_at_five = Some(Antichain::from_elem(Timestamp::from(5)));
        let input_at_three = Some(Antichain::from_elem(Timestamp::from(3)));

        // Register the watch up front and keep its receiver alive for the test.
        let (tx, rx) = watch::channel(FrontiersResponse::default());
        shared.lock().unwrap().frontiers.insert(id, tx);
        let mut raw_rx = shared.lock().unwrap().raw.subscribe();

        let first = FrontiersResponse {
            output_frontier: output_at_five.clone(),
            ..Default::default()
        };
        Responses::dispatch(&shared, ComputeResponse::Frontiers(id, first));
        assert_eq!(rx.borrow().output_frontier, output_at_five);

        let second = FrontiersResponse {
            input_frontier: input_at_three.clone(),
            ..Default::default()
        };
        Responses::dispatch(&shared, ComputeResponse::Frontiers(id, second));
        assert_eq!(rx.borrow().output_frontier, output_at_five);
        assert_eq!(rx.borrow().input_frontier, input_at_three);

        assert!(raw_rx.try_recv().is_ok());
    }
}
305}