mz_clusterd_test_driver/
responses.rs1use std::collections::BTreeMap;
17use std::sync::{Arc, Mutex};
18
19use mz_compute_client::protocol::command::ComputeCommand;
20use mz_compute_client::protocol::response::{
21 ComputeResponse, FrontiersResponse, PeekResponse, SubscribeResponse,
22};
23use mz_repr::{GlobalId, Row, Timestamp};
24use mz_service::client::GenericClient;
25use timely::progress::Antichain;
26use tokio::sync::{broadcast, oneshot, watch};
27
28use crate::ctp::ComputeCtpClient;
29
30type FrontierTx = watch::Sender<FrontiersResponse>;
31type FrontierRx = watch::Receiver<FrontiersResponse>;
32
33struct SubscribeState {
41 updates: Vec<(Row, Timestamp, i64)>,
42 upper_tx: watch::Sender<Antichain<Timestamp>>,
43 error: Option<String>,
44}
45
46impl SubscribeState {
47 fn new() -> Self {
48 SubscribeState {
49 updates: Vec::new(),
50 upper_tx: watch::channel(Antichain::from_elem(Timestamp::default())).0,
51 error: None,
52 }
53 }
54}
55
56struct Shared {
57 frontiers: BTreeMap<GlobalId, FrontierTx>,
58 peeks: BTreeMap<uuid::Uuid, oneshot::Sender<PeekResponse>>,
59 subscribes: BTreeMap<GlobalId, SubscribeState>,
60 raw: broadcast::Sender<ComputeResponse>,
61}
62
63#[derive(Clone)]
66pub struct Responses {
67 shared: Arc<Mutex<Shared>>,
68}
69
70impl Responses {
71 pub fn spawn(mut client: ComputeCtpClient) -> (Self, ComputeSender) {
73 let (raw_tx, _) = broadcast::channel(1024);
74 let shared = Arc::new(Mutex::new(Shared {
75 frontiers: BTreeMap::new(),
76 peeks: BTreeMap::new(),
77 subscribes: BTreeMap::new(),
78 raw: raw_tx,
79 }));
80 let pump_shared = Arc::clone(&shared);
81 let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<ComputeCommand>();
82 mz_ore::task::spawn(|| "compute_response_pump", async move {
83 loop {
84 tokio::select! {
85 cmd = cmd_rx.recv() => match cmd {
86 Some(cmd) => {
89 if let Err(e) = client.send(cmd).await {
90 tracing::error!("compute command send failed: {e}");
91 break;
92 }
93 }
94 None => break,
95 },
96 resp = client.recv() => match resp {
97 Ok(Some(resp)) => Self::dispatch(&pump_shared, resp),
98 Ok(None) => {
101 tracing::warn!("clusterd closed the compute connection");
102 break;
103 }
104 Err(e) => {
105 tracing::error!("compute response recv failed: {e}");
106 break;
107 }
108 },
109 }
110 }
111 });
112 (Responses { shared }, ComputeSender { tx: cmd_tx })
113 }
114
115 fn dispatch(shared: &Arc<Mutex<Shared>>, resp: ComputeResponse) {
116 let mut g = shared.lock().expect("lock");
117 let _ = g.raw.send(resp.clone());
118 match resp {
119 ComputeResponse::Frontiers(id, f) => {
120 let tx = g
121 .frontiers
122 .entry(id)
123 .or_insert_with(|| watch::channel(FrontiersResponse::default()).0);
124 let mut cur = tx.borrow().clone();
125 if f.write_frontier.is_some() {
126 cur.write_frontier = f.write_frontier;
127 }
128 if f.input_frontier.is_some() {
129 cur.input_frontier = f.input_frontier;
130 }
131 if f.output_frontier.is_some() {
132 cur.output_frontier = f.output_frontier;
133 }
134 let _ = tx.send(cur);
135 }
136 ComputeResponse::PeekResponse(uuid, pr, _otel) => {
137 if let Some(tx) = g.peeks.remove(&uuid) {
138 let _ = tx.send(pr);
139 }
140 }
141 ComputeResponse::SubscribeResponse(id, sr) => {
142 let state = g.subscribes.entry(id).or_insert_with(SubscribeState::new);
143 let upper = match sr {
144 SubscribeResponse::Batch(batch) => {
145 match batch.updates {
146 Ok(collections) => {
147 for collection in collections {
148 for (row, ts, diff) in collection.iter() {
149 state.updates.push((
150 row.to_owned(),
151 *ts,
152 diff.into_inner(),
153 ));
154 }
155 }
156 }
157 Err(e) => {
160 if state.error.is_none() {
161 state.error = Some(e);
162 }
163 }
164 }
165 batch.upper
166 }
167 SubscribeResponse::DroppedAt(frontier) => frontier,
170 };
171 let _ = state.upper_tx.send(upper);
172 }
173 _ => {}
174 }
175 }
176
177 pub fn frontier(&self, id: GlobalId) -> FrontierRx {
180 let mut g = self.shared.lock().expect("lock");
181 g.frontiers
182 .entry(id)
183 .or_insert_with(|| watch::channel(FrontiersResponse::default()).0)
184 .subscribe()
185 }
186
187 pub fn subscribe_raw(&self) -> broadcast::Receiver<ComputeResponse> {
189 self.shared.lock().expect("lock").raw.subscribe()
190 }
191
192 pub fn register_peek(&self, uuid: uuid::Uuid) -> oneshot::Receiver<PeekResponse> {
194 let (tx, rx) = oneshot::channel();
195 self.shared.lock().expect("lock").peeks.insert(uuid, tx);
196 rx
197 }
198
199 pub fn ensure_subscribe(&self, id: GlobalId) -> watch::Receiver<Antichain<Timestamp>> {
203 let mut g = self.shared.lock().expect("lock");
204 g.subscribes
205 .entry(id)
206 .or_insert_with(SubscribeState::new)
207 .upper_tx
208 .subscribe()
209 }
210
211 pub fn drain_subscribe(&self, id: GlobalId) -> anyhow::Result<Vec<(Row, Timestamp, i64)>> {
216 let mut g = self.shared.lock().expect("lock");
217 let state = g
218 .subscribes
219 .get_mut(&id)
220 .ok_or_else(|| anyhow::anyhow!("no subscribe registered for {id}"))?;
221 if let Some(e) = &state.error {
222 anyhow::bail!("subscribe {id} reported an error: {e}");
223 }
224 Ok(std::mem::take(&mut state.updates))
225 }
226}
227
228#[derive(Clone)]
230pub struct ComputeSender {
231 tx: tokio::sync::mpsc::UnboundedSender<ComputeCommand>,
232}
233
234impl ComputeSender {
235 pub fn send(&self, cmd: ComputeCommand) -> anyhow::Result<()> {
236 self.tx
237 .send(cmd)
238 .map_err(|_| anyhow::anyhow!("pump task gone"))?;
239 Ok(())
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use mz_repr::Timestamp;
    use timely::progress::Antichain;

    /// Builds a `Shared` with empty routing tables and a 16-slot raw broadcast buffer.
    fn empty_shared() -> Arc<Mutex<Shared>> {
        let raw = broadcast::channel(16).0;
        let shared = Shared {
            raw,
            subscribes: BTreeMap::new(),
            peeks: BTreeMap::new(),
            frontiers: BTreeMap::new(),
        };
        Arc::new(Mutex::new(shared))
    }

    /// `Some` single-element antichain at timestamp `t`.
    fn at(t: u64) -> Option<Antichain<Timestamp>> {
        Some(Antichain::from_elem(Timestamp::from(t)))
    }

    #[mz_ore::test]
    fn dispatch_merges_frontier_and_broadcasts() {
        let shared = empty_shared();
        let id = GlobalId::User(1);

        // Pre-register a frontier channel and a raw observer before any dispatch.
        let (frontier_tx, rx) = watch::channel(FrontiersResponse::default());
        let mut raw_rx = {
            let mut guard = shared.lock().unwrap();
            guard.frontiers.insert(id, frontier_tx);
            guard.raw.subscribe()
        };

        // A response carrying only the output frontier.
        let output_only = FrontiersResponse {
            output_frontier: at(5),
            ..Default::default()
        };
        Responses::dispatch(&shared, ComputeResponse::Frontiers(id, output_only));
        assert_eq!(rx.borrow().output_frontier, at(5));

        // A later response carrying only the input frontier must not clobber the
        // previously reported output frontier.
        let input_only = FrontiersResponse {
            input_frontier: at(3),
            ..Default::default()
        };
        Responses::dispatch(&shared, ComputeResponse::Frontiers(id, input_only));
        {
            let merged = rx.borrow();
            assert_eq!(merged.output_frontier, at(5));
            assert_eq!(merged.input_frontier, at(3));
        }

        assert!(raw_rx.try_recv().is_ok());
    }
}