mz_txn_wal/
txns.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An interface for atomic multi-shard writes.
11
12use std::collections::BTreeMap;
13use std::fmt::Debug;
14use std::ops::{Deref, DerefMut};
15use std::sync::Arc;
16
17use differential_dataflow::difference::Semigroup;
18use differential_dataflow::lattice::Lattice;
19use futures::StreamExt;
20use futures::stream::FuturesUnordered;
21use mz_dyncfg::{Config, ConfigSet, ConfigValHandle};
22use mz_ore::collections::HashSet;
23use mz_ore::instrument;
24use mz_persist_client::batch::Batch;
25use mz_persist_client::cfg::USE_CRITICAL_SINCE_TXN;
26use mz_persist_client::critical::SinceHandle;
27use mz_persist_client::write::WriteHandle;
28use mz_persist_client::{Diagnostics, PersistClient, ShardId};
29use mz_persist_types::schema::SchemaId;
30use mz_persist_types::txn::{TxnsCodec, TxnsEntry};
31use mz_persist_types::{Codec, Codec64, Opaque, StepForward};
32use timely::order::TotalOrder;
33use timely::progress::Timestamp;
34use tracing::debug;
35
36use crate::TxnsCodecDefault;
37use crate::metrics::Metrics;
38use crate::txn_cache::{TxnsCache, Unapplied};
39use crate::txn_write::Txn;
40
41/// An interface for atomic multi-shard writes.
42///
43/// This handle is acquired through [Self::open]. Any data shards must be
44/// registered with [Self::register] before use. Transactions are then started
45/// with [Self::begin].
46///
47/// # Implementation Details
48///
49/// The structure of the txns shard is `(ShardId, Vec<u8>)` updates.
50///
51/// The core mechanism is that a txn commits a set of transmittable persist
52/// _batch handles_ as `(ShardId, <opaque blob>)` pairs at a single timestamp.
53/// This contractually both commits the txn and advances the logical upper of
54/// _every_ data shard (not just the ones involved in the txn).
55///
56/// Example:
57///
58/// ```text
59/// // A txn to only d0 at ts=1
60/// (d0, <opaque blob A>, 1, 1)
61/// // A txn to d0 (two blobs) and d1 (one blob) at ts=4
62/// (d0, <opaque blob B>, 4, 1)
63/// (d0, <opaque blob C>, 4, 1)
64/// (d1, <opaque blob D>, 4, 1)
65/// ```
66///
67/// However, the new commit is not yet readable until the txn apply has run,
68/// which is expected to be promptly done by the committer, except in the event
69/// of a crash. This, in ts order, moves the batch handles into the data shards
70/// with a [compare_and_append_batch] (similar to how the multi-worker
71/// persist_sink works).
72///
73/// [compare_and_append_batch]:
74///     mz_persist_client::write::WriteHandle::compare_and_append_batch
75///
76/// Once apply is run, we "tidy" the txns shard by retracting the update adding
77/// the batch. As a result, the contents of the txns shard at any given
78/// timestamp is exactly the set of outstanding apply work (plus registrations,
79/// see below).
80///
81/// Example (building on the above):
82///
83/// ```text
84/// // Tidy for the first txn at ts=3
85/// (d0, <opaque blob A>, 3, -1)
86/// // Tidy for the second txn (the timestamps can be different for each
87/// // retraction in a txn, but don't need to be)
88/// (d0, <opaque blob B>, 5, -1)
89/// (d0, <opaque blob C>, 6, -1)
90/// (d1, <opaque blob D>, 6, -1)
91/// ```
92///
93/// To make it easy to reason about exactly which data shards are registered in
94/// the txn set at any given moment, the data shard is added to the set with a
95/// `(ShardId, <empty>)` pair. The data may not be read before the timestamp of
96/// the update (which starts at the time it was initialized, but it may later be
97/// forwarded).
98///
99/// Example (building on both of the above):
100///
101/// ```text
102/// // d0 and d1 were both initialized before they were used above
103/// (d0, <empty>, 0, 1)
104/// (d1, <empty>, 2, 1)
105/// ```
106#[derive(Debug)]
107pub struct TxnsHandle<K: Codec, V: Codec, T, D, O = u64, C: TxnsCodec = TxnsCodecDefault> {
108    pub(crate) metrics: Arc<Metrics>,
109    pub(crate) txns_cache: TxnsCache<T, C>,
110    pub(crate) txns_write: WriteHandle<C::Key, C::Val, T, i64>,
111    pub(crate) txns_since: SinceHandle<C::Key, C::Val, T, i64, O>,
112    pub(crate) datas: DataHandles<K, V, T, D>,
113}
114
115impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
116where
117    K: Debug + Codec,
118    V: Debug + Codec,
119    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
120    D: Debug + Semigroup + Ord + Codec64 + Send + Sync,
121    O: Opaque + Debug + Codec64,
122    C: TxnsCodec,
123{
124    /// Returns a [TxnsHandle] committing to the given txn shard.
125    ///
126    /// `txns_id` identifies which shard will be used as the txns WAL. MZ will
127    /// likely have one of these per env, used by all processes and the same
128    /// across restarts.
129    ///
130    /// This also does any (idempotent) initialization work: i.e. ensures that
131    /// the txn shard is readable at `init_ts` by appending an empty batch, if
132    /// necessary.
133    pub async fn open(
134        init_ts: T,
135        client: PersistClient,
136        dyncfgs: ConfigSet,
137        metrics: Arc<Metrics>,
138        txns_id: ShardId,
139    ) -> Self {
140        let (txns_key_schema, txns_val_schema) = C::schemas();
141        let (mut txns_write, txns_read) = client
142            .open(
143                txns_id,
144                Arc::new(txns_key_schema),
145                Arc::new(txns_val_schema),
146                Diagnostics {
147                    shard_name: "txns".to_owned(),
148                    handle_purpose: "commit txns".to_owned(),
149                },
150                USE_CRITICAL_SINCE_TXN.get(client.dyncfgs()),
151            )
152            .await
153            .expect("txns schema shouldn't change");
154        let txns_since = client
155            .open_critical_since(
156                txns_id,
157                // TODO: We likely need to use a different critical reader id
158                // for this if we want to be able to introspect it via SQL.
159                PersistClient::CONTROLLER_CRITICAL_SINCE,
160                Diagnostics {
161                    shard_name: "txns".to_owned(),
162                    handle_purpose: "commit txns".to_owned(),
163                },
164            )
165            .await
166            .expect("txns schema shouldn't change");
167        let txns_cache = TxnsCache::init(init_ts, txns_read, &mut txns_write).await;
168        TxnsHandle {
169            metrics,
170            txns_cache,
171            txns_write,
172            txns_since,
173            datas: DataHandles {
174                dyncfgs,
175                client: Arc::new(client),
176                data_write_for_apply: BTreeMap::new(),
177                data_write_for_commit: BTreeMap::new(),
178            },
179        }
180    }
181
182    /// Returns a new, empty transaction that can involve the data shards
183    /// registered with this handle.
184    pub fn begin(&self) -> Txn<K, V, T, D> {
185        // TODO: This is a method on the handle because we'll need WriteHandles
186        // once we start spilling to s3.
187        Txn::new()
188    }
189
190    /// Registers data shards for use with this txn set.
191    ///
192    /// A registration entry is written to the txn shard. If it is not possible
193    /// to register the data at the requested time, an Err will be returned with
194    /// the minimum time the data shards could be registered.
195    ///
196    /// This method is idempotent. Data shards currently registered at
197    /// `register_ts` will not be registered a second time. Specifically, this
198    /// method will return success when the most recent register ts `R` is
199    /// less_equal to `register_ts` AND there is no forget ts between `R` and
200    /// `register_ts`.
201    ///
202    /// As a side effect all txns <= register_ts are applied, including the
203    /// registration itself.
204    ///
205    /// **WARNING!** While a data shard is registered to the txn set, writing to
206    /// it directly (i.e. using a WriteHandle instead of the TxnHandle,
207    /// registering it with another txn shard) will lead to incorrectness,
208    /// undefined behavior, and (potentially sticky) panics.
209    #[instrument(level = "debug", fields(ts = ?register_ts))]
210    pub async fn register(
211        &mut self,
212        register_ts: T,
213        data_writes: impl IntoIterator<Item = WriteHandle<K, V, T, D>>,
214    ) -> Result<Tidy, T> {
215        let op = &Arc::clone(&self.metrics).register;
216        op.run(async {
217            let data_writes = data_writes.into_iter().collect::<Vec<_>>();
218            let updates = data_writes
219                .iter()
220                .map(|data_write| {
221                    let data_id = data_write.shard_id();
222                    let entry = TxnsEntry::Register(data_id, T::encode(&register_ts));
223                    (data_id, C::encode(entry))
224                })
225                .collect::<Vec<_>>();
226            let data_ids_debug = || {
227                data_writes
228                    .iter()
229                    .map(|x| format!("{:.9}", x.shard_id().to_string()))
230                    .collect::<Vec<_>>()
231                    .join(" ")
232            };
233
234            let mut txns_upper = self
235                .txns_write
236                .shared_upper()
237                .into_option()
238                .expect("txns should not be closed");
239            loop {
240                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
241                // Figure out which are still unregistered as of `txns_upper`. Below
242                // we write conditionally on the upper being what we expect so than
243                // we can re-run this if anything changes from underneath us.
244                let updates = updates
245                    .iter()
246                    .flat_map(|(data_id, (key, val))| {
247                        let registered =
248                            self.txns_cache.registered_at_progress(data_id, &txns_upper);
249                        (!registered).then_some(((key, val), &register_ts, 1))
250                    })
251                    .collect::<Vec<_>>();
252                // If the txns_upper has passed register_ts, we can no longer write.
253                if register_ts < txns_upper {
254                    debug!(
255                        "txns register {} at {:?} mismatch current={:?}",
256                        data_ids_debug(),
257                        register_ts,
258                        txns_upper,
259                    );
260                    return Err(txns_upper);
261                }
262
263                let res = crate::small_caa(
264                    || format!("txns register {}", data_ids_debug()),
265                    &mut self.txns_write,
266                    &updates,
267                    txns_upper,
268                    register_ts.step_forward(),
269                )
270                .await;
271                match res {
272                    Ok(()) => {
273                        debug!(
274                            "txns register {} at {:?} success",
275                            data_ids_debug(),
276                            register_ts
277                        );
278                        break;
279                    }
280                    Err(new_txns_upper) => {
281                        self.metrics.register.retry_count.inc();
282                        txns_upper = new_txns_upper;
283                        continue;
284                    }
285                }
286            }
287            for data_write in data_writes {
288                let new_schema_id = data_write.schema_id();
289
290                // If we already have a write handle for a newer version of a table, don't replace
291                // it! Currently we only support adding columns to tables with a default value, so
292                // the latest/newest schema will always be the most complete.
293                //
294                // TODO(alter_table): Revist when we support dropping columns.
295                match self.datas.data_write_for_commit.get(&data_write.shard_id()) {
296                    None => {
297                        self.datas
298                            .data_write_for_commit
299                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
300                    }
301                    Some(previous) => {
302                        match (previous.schema_id(), new_schema_id) {
303                            (Some(previous_id), None) => {
304                                mz_ore::soft_panic_or_log!(
305                                    "tried registering a WriteHandle replacing one with a SchemaId prev_schema_id: {:?} shard_id: {:?}",
306                                    previous_id,
307                                    previous.shard_id(),
308                                );
309                            },
310                            (Some(previous_id), Some(new_id)) if previous_id > new_id => {
311                                mz_ore::soft_panic_or_log!(
312                                    "tried registering a WriteHandle with an older SchemaId prev_schema_id: {:?} new_schema_id: {:?} shard_id: {:?}",
313                                    previous_id,
314                                    new_id,
315                                    previous.shard_id(),
316                                );
317                            },
318                            (previous_schema_id, new_schema_id) => {
319                                if previous_schema_id.is_none() && new_schema_id.is_none() {
320                                    tracing::warn!("replacing WriteHandle without any SchemaIds to reason about");
321                                } else {
322                                    tracing::info!(?previous_schema_id, ?new_schema_id, shard_id = ?previous.shard_id(), "replacing WriteHandle");
323                                }
324                                self.datas.data_write_for_commit.insert(data_write.shard_id(), DataWriteCommit(data_write));
325                            }
326                        }
327                    }
328                }
329            }
330            let tidy = self.apply_le(&register_ts).await;
331
332            Ok(tidy)
333        })
334        .await
335    }
336
337    /// Removes data shards from use with this txn set.
338    ///
339    /// The registration entry written to the txn shard is retracted. If it is
340    /// not possible to forget the data shard at the requested time, an Err will
341    /// be returned with the minimum time the data shards could be forgotten.
342    ///
343    /// This method is idempotent. Data shards currently forgotten at
344    /// `forget_ts` will not be forgotten a second time. Specifically, this
345    /// method will return success when the most recent forget ts (if any) `F`
346    /// is less_equal to `forget_ts` AND there is no register ts between `F` and
347    /// `forget_ts`.
348    ///
349    /// As a side effect all txns <= forget_ts are applied, including the
350    /// forget itself.
351    ///
352    /// **WARNING!** While a data shard is registered to the txn set, writing to
353    /// it directly (i.e. using a WriteHandle instead of the TxnHandle,
354    /// registering it with another txn shard) will lead to incorrectness,
355    /// undefined behavior, and (potentially sticky) panics.
356    #[instrument(level = "debug", fields(ts = ?forget_ts))]
357    pub async fn forget(
358        &mut self,
359        forget_ts: T,
360        data_ids: impl IntoIterator<Item = ShardId>,
361    ) -> Result<Tidy, T> {
362        let op = &Arc::clone(&self.metrics).forget;
363        op.run(async {
364            let data_ids = data_ids.into_iter().collect::<Vec<_>>();
365            let mut txns_upper = self
366                .txns_write
367                .shared_upper()
368                .into_option()
369                .expect("txns should not be closed");
370            loop {
371                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
372
373                let data_ids_debug = || {
374                    data_ids
375                        .iter()
376                        .map(|x| format!("{:.9}", x.to_string()))
377                        .collect::<Vec<_>>()
378                        .join(" ")
379                };
380                let updates = data_ids
381                    .iter()
382                    // Never registered or already forgotten. This could change in
383                    // `[txns_upper, forget_ts]` (due to races) so close off that
384                    // interval before returning, just don't write any updates.
385                    .filter(|data_id| self.txns_cache.registered_at_progress(data_id, &txns_upper))
386                    .map(|data_id| C::encode(TxnsEntry::Register(*data_id, T::encode(&forget_ts))))
387                    .collect::<Vec<_>>();
388                let updates = updates
389                    .iter()
390                    .map(|(key, val)| ((key, val), &forget_ts, -1))
391                    .collect::<Vec<_>>();
392
393                // If the txns_upper has passed forget_ts, we can no longer write.
394                if forget_ts < txns_upper {
395                    debug!(
396                        "txns forget {} at {:?} mismatch current={:?}",
397                        data_ids_debug(),
398                        forget_ts,
399                        txns_upper,
400                    );
401                    return Err(txns_upper);
402                }
403
404                // Ensure the latest writes for each shard has been applied, so we don't run into
405                // any issues trying to apply it later.
406                {
407                    let data_ids: HashSet<_> = data_ids.iter().cloned().collect();
408                    let data_latest_unapplied = self
409                        .txns_cache
410                        .unapplied_batches
411                        .values()
412                        .rev()
413                        .find(|(x, _, _)| data_ids.contains(x));
414                    if let Some((_, _, latest_write)) = data_latest_unapplied {
415                        debug!(
416                            "txns forget {} applying latest write {:?}",
417                            data_ids_debug(),
418                            latest_write,
419                        );
420                        let latest_write = latest_write.clone();
421                        let _tidy = self.apply_le(&latest_write).await;
422                    }
423                }
424                let res = crate::small_caa(
425                    || format!("txns forget {}", data_ids_debug()),
426                    &mut self.txns_write,
427                    &updates,
428                    txns_upper,
429                    forget_ts.step_forward(),
430                )
431                .await;
432                match res {
433                    Ok(()) => {
434                        debug!(
435                            "txns forget {} at {:?} success",
436                            data_ids_debug(),
437                            forget_ts
438                        );
439                        break;
440                    }
441                    Err(new_txns_upper) => {
442                        self.metrics.forget.retry_count.inc();
443                        txns_upper = new_txns_upper;
444                        continue;
445                    }
446                }
447            }
448
449            // Note: Ordering here matters, we want to generate the Tidy work _before_ removing the
450            // handle because the work will create a handle to the shard.
451            let tidy = self.apply_le(&forget_ts).await;
452            for data_id in &data_ids {
453                self.datas.data_write_for_commit.remove(data_id);
454            }
455
456            Ok(tidy)
457        })
458        .await
459    }
460
461    /// Forgets, at the given timestamp, every data shard that is registered.
462    /// Returns the ids of the forgotten shards. See [Self::forget].
463    #[instrument(level = "debug", fields(ts = ?forget_ts))]
464    pub async fn forget_all(&mut self, forget_ts: T) -> Result<(Vec<ShardId>, Tidy), T> {
465        let op = &Arc::clone(&self.metrics).forget_all;
466        op.run(async {
467            let mut txns_upper = self
468                .txns_write
469                .shared_upper()
470                .into_option()
471                .expect("txns should not be closed");
472            let registered = loop {
473                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
474
475                let registered = self.txns_cache.all_registered_at_progress(&txns_upper);
476                let data_ids_debug = || {
477                    registered
478                        .iter()
479                        .map(|x| format!("{:.9}", x.to_string()))
480                        .collect::<Vec<_>>()
481                        .join(" ")
482                };
483                let updates = registered
484                    .iter()
485                    .map(|data_id| {
486                        C::encode(crate::TxnsEntry::Register(*data_id, T::encode(&forget_ts)))
487                    })
488                    .collect::<Vec<_>>();
489                let updates = updates
490                    .iter()
491                    .map(|(key, val)| ((key, val), &forget_ts, -1))
492                    .collect::<Vec<_>>();
493
494                // If the txns_upper has passed forget_ts, we can no longer write.
495                if forget_ts < txns_upper {
496                    debug!(
497                        "txns forget_all {} at {:?} mismatch current={:?}",
498                        data_ids_debug(),
499                        forget_ts,
500                        txns_upper,
501                    );
502                    return Err(txns_upper);
503                }
504
505                // Ensure the latest write has been applied, so we don't run into
506                // any issues trying to apply it later.
507                //
508                // NB: It's _very_ important for correctness to get this from the
509                // unapplied batches (which compact themselves naturally) and not
510                // from the writes (which are artificially compacted based on when
511                // we need reads for).
512                let data_latest_unapplied = self.txns_cache.unapplied_batches.values().last();
513                if let Some((_, _, latest_write)) = data_latest_unapplied {
514                    debug!(
515                        "txns forget_all {} applying latest write {:?}",
516                        data_ids_debug(),
517                        latest_write,
518                    );
519                    let latest_write = latest_write.clone();
520                    let _tidy = self.apply_le(&latest_write).await;
521                }
522                let res = crate::small_caa(
523                    || format!("txns forget_all {}", data_ids_debug()),
524                    &mut self.txns_write,
525                    &updates,
526                    txns_upper,
527                    forget_ts.step_forward(),
528                )
529                .await;
530                match res {
531                    Ok(()) => {
532                        debug!(
533                            "txns forget_all {} at {:?} success",
534                            data_ids_debug(),
535                            forget_ts
536                        );
537                        break registered;
538                    }
539                    Err(new_txns_upper) => {
540                        self.metrics.forget_all.retry_count.inc();
541                        txns_upper = new_txns_upper;
542                        continue;
543                    }
544                }
545            };
546
547            for data_id in registered.iter() {
548                self.datas.data_write_for_commit.remove(data_id);
549            }
550            let tidy = self.apply_le(&forget_ts).await;
551
552            Ok((registered, tidy))
553        })
554        .await
555    }
556
557    /// "Applies" all committed txns <= the given timestamp, ensuring that reads
558    /// at that timestamp will not block.
559    ///
560    /// In the common case, the txn committer will have done this work and this
561    /// method will be a no-op, but it is not guaranteed. In the event of a
562    /// crash or race, this does whatever persist writes are necessary (and
563    /// returns the resulting maintenance work), which could be significant.
564    ///
565    /// If the requested timestamp has not yet been written, this could block
566    /// for an unbounded amount of time.
567    ///
568    /// This method is idempotent.
569    #[instrument(level = "debug", fields(ts = ?ts))]
570    pub async fn apply_le(&mut self, ts: &T) -> Tidy {
571        let op = &self.metrics.apply_le;
572        op.run(async {
573            debug!("apply_le {:?}", ts);
574            let _ = self.txns_cache.update_gt(ts).await;
575            self.txns_cache.update_gauges(&self.metrics);
576
577            let mut unapplied_by_data = BTreeMap::<_, Vec<_>>::new();
578            for (data_id, unapplied, unapplied_ts) in self.txns_cache.unapplied() {
579                if ts < unapplied_ts {
580                    break;
581                }
582                unapplied_by_data
583                    .entry(*data_id)
584                    .or_default()
585                    .push((unapplied, unapplied_ts));
586            }
587
588            let retractions = FuturesUnordered::new();
589            for (data_id, unapplied) in unapplied_by_data {
590                let mut data_write = self.datas.take_write_for_apply(&data_id).await;
591                retractions.push(async move {
592                    let mut ret = Vec::new();
593                    for (unapplied, unapplied_ts) in unapplied {
594                        match unapplied {
595                            Unapplied::RegisterForget => {
596                                let () = crate::empty_caa(
597                                    || {
598                                        format!(
599                                            "data {:.9} register/forget fill",
600                                            data_id.to_string()
601                                        )
602                                    },
603                                    &mut data_write,
604                                    unapplied_ts.clone(),
605                                )
606                                .await;
607                            }
608                            Unapplied::Batch(batch_raws) => {
609                                let batch_raws = batch_raws
610                                    .into_iter()
611                                    .map(|batch_raw| batch_raw.as_slice())
612                                    .collect();
613                                crate::apply_caa(
614                                    &mut data_write,
615                                    &batch_raws,
616                                    unapplied_ts.clone(),
617                                )
618                                .await;
619                                for batch_raw in batch_raws {
620                                    // NB: Protos are not guaranteed to exactly roundtrip the
621                                    // encoded bytes, so we intentionally use the raw batch so that
622                                    // it definitely retracts.
623                                    ret.push((
624                                        batch_raw.to_vec(),
625                                        (T::encode(unapplied_ts), data_id),
626                                    ));
627                                }
628                            }
629                        }
630                    }
631                    (data_write, ret)
632                });
633            }
634            let retractions = retractions.collect::<Vec<_>>().await;
635            let retractions = retractions
636                .into_iter()
637                .flat_map(|(data_write, retractions)| {
638                    self.datas.put_write_for_apply(data_write);
639                    retractions
640                })
641                .collect();
642
643            // Remove all the applied registers.
644            self.txns_cache.mark_register_applied(ts);
645
646            debug!("apply_le {:?} success", ts);
647            Tidy { retractions }
648        })
649        .await
650    }
651
652    /// Commits the tidy work at the given time.
653    ///
654    /// Mostly a helper to make it obvious that we can throw away the apply work
655    /// (and not get into an infinite cycle of tidy->apply->tidy).
656    #[cfg(test)]
657    pub async fn tidy_at(&mut self, tidy_ts: T, tidy: Tidy) -> Result<(), T> {
658        debug!("tidy at {:?}", tidy_ts);
659
660        let mut txn = self.begin();
661        txn.tidy(tidy);
662        // We just constructed this txn, so it couldn't have committed any
663        // batches, and thus there's nothing to apply. We're free to throw it
664        // away.
665        let apply = txn.commit_at(self, tidy_ts.clone()).await?;
666        assert!(apply.is_empty());
667
668        debug!("tidy at {:?} success", tidy_ts);
669        Ok(())
670    }
671
672    /// Allows compaction to the txns shard as well as internal representations,
673    /// losing the ability to answer queries about times less_than since_ts.
674    ///
675    /// In practice, this will likely only be called from the singleton
676    /// controller process.
677    pub async fn compact_to(&mut self, mut since_ts: T) {
678        let op = &self.metrics.compact_to;
679        op.run(async {
680            tracing::debug!("compact_to {:?}", since_ts);
681            let _ = self.txns_cache.update_gt(&since_ts).await;
682
683            // NB: A critical invariant for how this all works is that we never
684            // allow the since of the txns shard to pass any unapplied writes, so
685            // reduce it as necessary.
686            let min_unapplied_ts = self.txns_cache.min_unapplied_ts();
687            if min_unapplied_ts < &since_ts {
688                since_ts.clone_from(min_unapplied_ts);
689            }
690            crate::cads::<T, O, C>(&mut self.txns_since, since_ts).await;
691        })
692        .await
693    }
694
695    /// Returns the [ShardId] of the txns shard.
696    pub fn txns_id(&self) -> ShardId {
697        self.txns_write.shard_id()
698    }
699
700    /// Returns the [TxnsCache] used by this handle.
701    pub fn read_cache(&self) -> &TxnsCache<T, C> {
702        &self.txns_cache
703    }
704}
705
706/// A token representing maintenance writes (in particular, retractions) to the
707/// txns shard.
708///
709/// This can be written on its own with `TxnsHandle::tidy_at` or sidecar'd into
710/// a normal txn with [Txn::tidy].
711#[derive(Debug, Default)]
712pub struct Tidy {
713    pub(crate) retractions: BTreeMap<Vec<u8>, ([u8; 8], ShardId)>,
714}
715
716impl Tidy {
717    /// Merges the work represented by the other tidy into this one.
718    pub fn merge(&mut self, other: Tidy) {
719        self.retractions.extend(other.retractions)
720    }
721}
722
723/// A helper to make a more targeted mutable borrow of self.
724#[derive(Debug)]
725pub(crate) struct DataHandles<K: Codec, V: Codec, T, D> {
726    pub(crate) dyncfgs: ConfigSet,
727    pub(crate) client: Arc<PersistClient>,
728    /// See [DataWriteApply].
729    ///
730    /// This is lazily populated with the set of shards touched by `apply_le`.
731    data_write_for_apply: BTreeMap<ShardId, DataWriteApply<K, V, T, D>>,
732    /// See [DataWriteCommit].
733    ///
734    /// This contains the set of data shards registered but not yet forgotten
735    /// with this particular write handle.
736    ///
737    /// NB: In the common case, this and `_for_apply` will contain the same set
738    /// of shards, but this is not required. A shard can be in either and not
739    /// the other.
740    data_write_for_commit: BTreeMap<ShardId, DataWriteCommit<K, V, T, D>>,
741}
742
743impl<K, V, T, D> DataHandles<K, V, T, D>
744where
745    K: Debug + Codec,
746    V: Debug + Codec,
747    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
748    D: Semigroup + Ord + Codec64 + Send + Sync,
749{
750    async fn open_data_write_for_apply(&self, data_id: ShardId) -> DataWriteApply<K, V, T, D> {
751        let diagnostics = Diagnostics::from_purpose("txn data");
752        let schemas = self
753            .client
754            .latest_schema::<K, V, T, D>(data_id, diagnostics.clone())
755            .await
756            .expect("codecs have not changed");
757        let (key_schema, val_schema) = match schemas {
758            Some((_, key_schema, val_schema)) => (Arc::new(key_schema), Arc::new(val_schema)),
759            // - For new shards we will always have at least one schema
760            //   registered by the time we reach this point, because that
761            //   happens at txn-registration time.
762            // - For pre-existing shards, every txns shard will have had
763            //   open_writer called on it at least once in the previous release,
764            //   so the schema should exist.
765            None => unreachable!("data shard {} should have a schema", data_id),
766        };
767        let wrapped = self
768            .client
769            .open_writer(data_id, key_schema, val_schema, diagnostics)
770            .await
771            .expect("schema shouldn't change");
772        DataWriteApply {
773            apply_ensure_schema_match: APPLY_ENSURE_SCHEMA_MATCH.handle(&self.dyncfgs),
774            client: Arc::clone(&self.client),
775            wrapped,
776        }
777    }
778
779    pub(crate) async fn take_write_for_apply(
780        &mut self,
781        data_id: &ShardId,
782    ) -> DataWriteApply<K, V, T, D> {
783        if let Some(data_write) = self.data_write_for_apply.remove(data_id) {
784            return data_write;
785        }
786        self.open_data_write_for_apply(*data_id).await
787    }
788
789    pub(crate) fn put_write_for_apply(&mut self, data_write: DataWriteApply<K, V, T, D>) {
790        self.data_write_for_apply
791            .insert(data_write.shard_id(), data_write);
792    }
793
794    pub(crate) fn take_write_for_commit(
795        &mut self,
796        data_id: &ShardId,
797    ) -> Option<DataWriteCommit<K, V, T, D>> {
798        self.data_write_for_commit.remove(data_id)
799    }
800
801    pub(crate) fn put_write_for_commit(&mut self, data_write: DataWriteCommit<K, V, T, D>) {
802        let prev = self
803            .data_write_for_commit
804            .insert(data_write.shard_id(), data_write);
805        assert!(prev.is_none());
806    }
807}
808
809/// A newtype wrapper around [WriteHandle] indicating that it has a real schema
810/// registered by the user.
811///
812/// The txn-wal user declares which schema they'd like to use for committing
813/// batches by passing it in (as part of the WriteHandle) in the call to
814/// register. This must be used to encode any new batches written. The wrapper
815/// helps us from accidentally mixing up the WriteHandles that we internally
816/// invent for applying the batches (which use a schema matching the one
817/// declared in the batch).
818#[derive(Debug)]
819pub(crate) struct DataWriteCommit<K: Codec, V: Codec, T, D>(pub(crate) WriteHandle<K, V, T, D>);
820
821impl<K: Codec, V: Codec, T, D> Deref for DataWriteCommit<K, V, T, D> {
822    type Target = WriteHandle<K, V, T, D>;
823
824    fn deref(&self) -> &Self::Target {
825        &self.0
826    }
827}
828
829impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteCommit<K, V, T, D> {
830    fn deref_mut(&mut self) -> &mut Self::Target {
831        &mut self.0
832    }
833}
834
835/// A newtype wrapper around [WriteHandle] indicating that it can alter the
836/// schema its using to match the one in the batches being appended.
837///
838/// When a batch is committed to txn-wal, it contains metadata about which
839/// schemas were used to encode the data in it. Txn-wal then uses this info to
840/// make sure that in [TxnsHandle::apply_le], that the `compare_and_append` call
841/// happens on a handle with the same schema. This is accomplished by querying
842/// the persist schema registry.
843#[derive(Debug)]
844pub(crate) struct DataWriteApply<K: Codec, V: Codec, T, D> {
845    client: Arc<PersistClient>,
846    apply_ensure_schema_match: ConfigValHandle<bool>,
847    pub(crate) wrapped: WriteHandle<K, V, T, D>,
848}
849
850impl<K: Codec, V: Codec, T, D> Deref for DataWriteApply<K, V, T, D> {
851    type Target = WriteHandle<K, V, T, D>;
852
853    fn deref(&self) -> &Self::Target {
854        &self.wrapped
855    }
856}
857
858impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteApply<K, V, T, D> {
859    fn deref_mut(&mut self) -> &mut Self::Target {
860        &mut self.wrapped
861    }
862}
863
864pub(crate) const APPLY_ENSURE_SCHEMA_MATCH: Config<bool> = Config::new(
865    "txn_wal_apply_ensure_schema_match",
866    true,
867    "CYA to skip updating write handle to batch schema in apply",
868);
869
870fn at_most_one_schema(
871    schemas: impl Iterator<Item = SchemaId>,
872) -> Result<Option<SchemaId>, (SchemaId, SchemaId)> {
873    let mut schema = None;
874    for s in schemas {
875        match schema {
876            None => schema = Some(s),
877            Some(x) if s != x => return Err((s, x)),
878            Some(_) => continue,
879        }
880    }
881    Ok(schema)
882}
883
884impl<K, V, T, D> DataWriteApply<K, V, T, D>
885where
886    K: Debug + Codec,
887    V: Debug + Codec,
888    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
889    D: Semigroup + Ord + Codec64 + Send + Sync,
890{
891    pub(crate) async fn maybe_replace_with_batch_schema(&mut self, batches: &[Batch<K, V, T, D>]) {
892        // TODO: Remove this once everything is rolled out and we're sure it's
893        // not going to cause any issues.
894        if !self.apply_ensure_schema_match.get() {
895            return;
896        }
897        let batch_schema = at_most_one_schema(batches.iter().flat_map(|x| x.schemas()));
898        let batch_schema = batch_schema.unwrap_or_else(|_| {
899            panic!(
900                "txn-wal uses at most one schema to commit batches, got: {:?}",
901                batches.iter().flat_map(|x| x.schemas()).collect::<Vec<_>>()
902            )
903        });
904        let (batch_schema, handle_schema) = match (batch_schema, self.wrapped.schema_id()) {
905            (Some(batch_schema), Some(handle_schema)) if batch_schema != handle_schema => {
906                (batch_schema, handle_schema)
907            }
908            _ => return,
909        };
910
911        let data_id = self.shard_id();
912        let diagnostics = Diagnostics::from_purpose("txn data");
913        let (key_schema, val_schema) = self
914            .client
915            .get_schema::<K, V, T, D>(data_id, batch_schema, diagnostics.clone())
916            .await
917            .expect("codecs shouldn't change")
918            .expect("id must have been registered to create this batch");
919        let new_data_write = self
920            .client
921            .open_writer(
922                self.shard_id(),
923                Arc::new(key_schema),
924                Arc::new(val_schema),
925                diagnostics,
926            )
927            .await
928            .expect("codecs shouldn't change");
929        tracing::info!(
930            "updated {} write handle from {} to {} to apply batches",
931            data_id,
932            handle_schema,
933            batch_schema
934        );
935        assert_eq!(new_data_write.schema_id(), Some(batch_schema));
936        self.wrapped = new_data_write;
937    }
938}
939
940#[cfg(test)]
941mod tests {
942    use std::time::{Duration, UNIX_EPOCH};
943
944    use differential_dataflow::Hashable;
945    use futures::future::BoxFuture;
946    use mz_ore::assert_none;
947    use mz_ore::cast::CastFrom;
948    use mz_ore::collections::CollectionExt;
949    use mz_ore::metrics::MetricsRegistry;
950    use mz_persist_client::PersistLocation;
951    use mz_persist_client::cache::PersistClientCache;
952    use mz_persist_client::cfg::RetryParameters;
953    use rand::rngs::SmallRng;
954    use rand::{RngCore, SeedableRng};
955    use timely::progress::Antichain;
956    use tokio::sync::oneshot;
957    use tracing::{Instrument, info, info_span};
958
959    use crate::operator::DataSubscribe;
960    use crate::tests::{CommitLog, reader, write_directly, writer};
961
962    use super::*;
963
964    impl TxnsHandle<String, (), u64, i64, u64, TxnsCodecDefault> {
965        pub(crate) async fn expect_open(client: PersistClient) -> Self {
966            Self::expect_open_id(client, ShardId::new()).await
967        }
968
969        pub(crate) async fn expect_open_id(client: PersistClient, txns_id: ShardId) -> Self {
970            let dyncfgs = crate::all_dyncfgs(client.dyncfgs().clone());
971            Self::open(
972                0,
973                client,
974                dyncfgs,
975                Arc::new(Metrics::new(&MetricsRegistry::new())),
976                txns_id,
977            )
978            .await
979        }
980
981        pub(crate) fn new_log(&self) -> CommitLog {
982            CommitLog::new((*self.datas.client).clone(), self.txns_id())
983        }
984
985        pub(crate) async fn expect_register(&mut self, register_ts: u64) -> ShardId {
986            self.expect_registers(register_ts, 1).await.into_element()
987        }
988
989        pub(crate) async fn expect_registers(
990            &mut self,
991            register_ts: u64,
992            amount: usize,
993        ) -> Vec<ShardId> {
994            let data_ids: Vec<_> = (0..amount).map(|_| ShardId::new()).collect();
995            let mut writers = Vec::new();
996            for data_id in &data_ids {
997                writers.push(writer(&self.datas.client, *data_id).await);
998            }
999            self.register(register_ts, writers).await.unwrap();
1000            data_ids
1001        }
1002
1003        pub(crate) async fn expect_commit_at(
1004            &mut self,
1005            commit_ts: u64,
1006            data_id: ShardId,
1007            keys: &[&str],
1008            log: &CommitLog,
1009        ) -> Tidy {
1010            let mut txn = self.begin();
1011            for key in keys {
1012                txn.write(&data_id, (*key).into(), (), 1).await;
1013            }
1014            let tidy = txn
1015                .commit_at(self, commit_ts)
1016                .await
1017                .unwrap()
1018                .apply(self)
1019                .await;
1020            for key in keys {
1021                log.record((data_id, (*key).into(), commit_ts, 1));
1022            }
1023            tidy
1024        }
1025    }
1026
1027    #[mz_ore::test(tokio::test)]
1028    #[cfg_attr(miri, ignore)] // too slow
1029    async fn register_at() {
1030        let client = PersistClient::new_for_tests().await;
1031        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1032        let log = txns.new_log();
1033        let d0 = txns.expect_register(2).await;
1034
1035        // Register a second time is a no-op (idempotent).
1036        txns.register(3, [writer(&client, d0).await]).await.unwrap();
1037
1038        // Cannot register a new data shard at an already closed off time. An
1039        // error is returned with the first time that a registration would
1040        // succeed.
1041        let d1 = ShardId::new();
1042        assert_eq!(
1043            txns.register(2, [writer(&client, d1).await])
1044                .await
1045                .unwrap_err(),
1046            4
1047        );
1048
1049        // Can still register after txns have been committed.
1050        txns.expect_commit_at(4, d0, &["foo"], &log).await;
1051        txns.register(5, [writer(&client, d1).await]).await.unwrap();
1052
1053        // We can also register some new and some already registered shards.
1054        let d2 = ShardId::new();
1055        txns.register(6, [writer(&client, d0).await, writer(&client, d2).await])
1056            .await
1057            .unwrap();
1058
1059        let () = log.assert_snapshot(d0, 6).await;
1060        let () = log.assert_snapshot(d1, 6).await;
1061    }
1062
1063    /// A sanity check that CommitLog catches an incorrect usage (proxy for a
1064    /// bug that looks like an incorrect usage).
1065    #[mz_ore::test(tokio::test)]
1066    #[cfg_attr(miri, ignore)] // too slow
1067    #[should_panic(expected = "left: [(\"foo\", 2, 1)]\n right: [(\"foo\", 2, 2)]")]
1068    async fn incorrect_usage_register_write_same_time() {
1069        let client = PersistClient::new_for_tests().await;
1070        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1071        let log = txns.new_log();
1072        let d0 = txns.expect_register(1).await;
1073        let mut d0_write = writer(&client, d0).await;
1074
1075        // Commit a write at ts 2...
1076        let mut txn = txns.begin_test();
1077        txn.write(&d0, "foo".into(), (), 1).await;
1078        let apply = txn.commit_at(&mut txns, 2).await.unwrap();
1079        log.record_txn(2, &txn);
1080        // ... and (incorrectly) also write to the shard normally at ts 2.
1081        let () = d0_write
1082            .compare_and_append(
1083                &[(("foo".to_owned(), ()), 2, 1)],
1084                Antichain::from_elem(2),
1085                Antichain::from_elem(3),
1086            )
1087            .await
1088            .unwrap()
1089            .unwrap();
1090        log.record((d0, "foo".into(), 2, 1));
1091        apply.apply(&mut txns).await;
1092
1093        // Verify that CommitLog catches this.
1094        log.assert_snapshot(d0, 2).await;
1095    }
1096
1097    #[mz_ore::test(tokio::test)]
1098    #[cfg_attr(miri, ignore)] // too slow
1099    async fn forget_at() {
1100        let client = PersistClient::new_for_tests().await;
1101        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1102        let log = txns.new_log();
1103
1104        // Can forget a data_shard that has not been registered.
1105        txns.forget(1, [ShardId::new()]).await.unwrap();
1106
1107        // Can forget multiple data_shards that have not been registered.
1108        txns.forget(2, (0..5).map(|_| ShardId::new()))
1109            .await
1110            .unwrap();
1111
1112        // Can forget a registered shard.
1113        let d0 = txns.expect_register(3).await;
1114        txns.forget(4, [d0]).await.unwrap();
1115
1116        // Can forget multiple registered shards.
1117        let ds = txns.expect_registers(5, 5).await;
1118        txns.forget(6, ds.clone()).await.unwrap();
1119
1120        // Forget is idempotent.
1121        txns.forget(7, [d0]).await.unwrap();
1122        txns.forget(8, ds.clone()).await.unwrap();
1123
1124        // Cannot forget at an already closed off time. An error is returned
1125        // with the first time that a registration would succeed.
1126        let d1 = txns.expect_register(9).await;
1127        assert_eq!(txns.forget(9, [d1]).await.unwrap_err(), 10);
1128
1129        // Write to txns and to d0 directly.
1130        let mut d0_write = writer(&client, d0).await;
1131        txns.expect_commit_at(10, d1, &["d1"], &log).await;
1132        let updates = [(("d0".to_owned(), ()), 10, 1)];
1133        d0_write
1134            .compare_and_append(&updates, d0_write.shared_upper(), Antichain::from_elem(11))
1135            .await
1136            .unwrap()
1137            .unwrap();
1138        log.record((d0, "d0".into(), 10, 1));
1139
1140        // Can register and forget an already registered and forgotten shard.
1141        txns.register(11, [writer(&client, d0).await])
1142            .await
1143            .unwrap();
1144        let mut forget_expected = vec![d0, d1];
1145        forget_expected.sort();
1146        assert_eq!(txns.forget_all(12).await.unwrap().0, forget_expected);
1147
1148        // Close shard to writes
1149        d0_write
1150            .compare_and_append_batch(&mut [], d0_write.shared_upper(), Antichain::new())
1151            .await
1152            .unwrap()
1153            .unwrap();
1154
1155        let () = log.assert_snapshot(d0, 12).await;
1156        let () = log.assert_snapshot(d1, 12).await;
1157
1158        for di in ds {
1159            let mut di_write = writer(&client, di).await;
1160
1161            // Close shards to writes
1162            di_write
1163                .compare_and_append_batch(&mut [], di_write.shared_upper(), Antichain::new())
1164                .await
1165                .unwrap()
1166                .unwrap();
1167
1168            let () = log.assert_snapshot(di, 8).await;
1169        }
1170    }
1171
1172    #[mz_ore::test(tokio::test)]
1173    #[cfg_attr(miri, ignore)] // too slow
1174    async fn register_forget() {
1175        async fn step_some_past(subs: &mut Vec<DataSubscribe>, ts: u64) {
1176            for (idx, sub) in subs.iter_mut().enumerate() {
1177                // Only step some of them to try to maximize edge cases.
1178                if usize::cast_from(ts) % (idx + 1) == 0 {
1179                    async {
1180                        info!("stepping sub {} past {}", idx, ts);
1181                        sub.step_past(ts).await;
1182                    }
1183                    .instrument(info_span!("sub", idx))
1184                    .await;
1185                }
1186            }
1187        }
1188
1189        let client = PersistClient::new_for_tests().await;
1190        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1191        let log = txns.new_log();
1192        let d0 = ShardId::new();
1193        let mut d0_write = writer(&client, d0).await;
1194        let mut subs = Vec::new();
1195
1196        // Loop for a while doing the following:
1197        // - Write directly to some time before register ts
1198        // - Register
1199        // - Write via txns
1200        // - Forget
1201        //
1202        // After each step, make sure that a subscription started at that time
1203        // works and that all subscriptions can be stepped through the expected
1204        // timestamp.
1205        let mut ts = 0;
1206        while ts < 32 {
1207            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1208            ts += 1;
1209            info!("{} direct", ts);
1210            txns.begin().commit_at(&mut txns, ts).await.unwrap();
1211            write_directly(ts, &mut d0_write, &[&format!("d{}", ts)], &log).await;
1212            step_some_past(&mut subs, ts).await;
1213            if ts % 11 == 0 {
1214                txns.compact_to(ts).await;
1215            }
1216
1217            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1218            ts += 1;
1219            info!("{} register", ts);
1220            txns.register(ts, [writer(&client, d0).await])
1221                .await
1222                .unwrap();
1223            step_some_past(&mut subs, ts).await;
1224            if ts % 11 == 0 {
1225                txns.compact_to(ts).await;
1226            }
1227
1228            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1229            ts += 1;
1230            info!("{} txns", ts);
1231            txns.expect_commit_at(ts, d0, &[&format!("t{}", ts)], &log)
1232                .await;
1233            step_some_past(&mut subs, ts).await;
1234            if ts % 11 == 0 {
1235                txns.compact_to(ts).await;
1236            }
1237
1238            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1239            ts += 1;
1240            info!("{} forget", ts);
1241            txns.forget(ts, [d0]).await.unwrap();
1242            step_some_past(&mut subs, ts).await;
1243            if ts % 11 == 0 {
1244                txns.compact_to(ts).await;
1245            }
1246        }
1247
1248        // Check all the subscribes.
1249        for mut sub in subs.into_iter() {
1250            sub.step_past(ts).await;
1251            log.assert_eq(d0, sub.as_of, sub.progress(), sub.output().clone());
1252        }
1253    }
1254
1255    // Regression test for a bug encountered during initial development:
1256    // - task 0 commits to a data shard at ts T
1257    // - before task 0 can unblock T for reads, task 1 tries to register it at
1258    //   time T (ditto T+X), this does a caa of empty space, advancing the upper
1259    //   of the data shard to T+1
1260    // - task 0 attempts to caa in the batch, but finds out the upper is T+1 and
1261    //   assumes someone else did the work
1262    // - result: the write is lost
1263    #[mz_ore::test(tokio::test)]
1264    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1265    async fn race_data_shard_register_and_commit() {
1266        let client = PersistClient::new_for_tests().await;
1267        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1268        let d0 = txns.expect_register(1).await;
1269
1270        let mut txn = txns.begin();
1271        txn.write(&d0, "foo".into(), (), 1).await;
1272        let commit_apply = txn.commit_at(&mut txns, 2).await.unwrap();
1273
1274        txns.register(3, [writer(&client, d0).await]).await.unwrap();
1275
1276        // Make sure that we can read empty at the register commit time even
1277        // before the txn commit apply.
1278        let actual = txns.txns_cache.expect_snapshot(&client, d0, 1).await;
1279        assert_eq!(actual, Vec::<String>::new());
1280
1281        commit_apply.apply(&mut txns).await;
1282        let actual = txns.txns_cache.expect_snapshot(&client, d0, 2).await;
1283        assert_eq!(actual, vec!["foo".to_owned()]);
1284    }
1285
1286    // A test that applies a batch of writes all at once.
1287    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1288    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1289    async fn apply_many_ts() {
1290        let client = PersistClient::new_for_tests().await;
1291        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1292        let log = txns.new_log();
1293        let d0 = txns.expect_register(1).await;
1294
1295        for ts in 2..10 {
1296            let mut txn = txns.begin();
1297            txn.write(&d0, ts.to_string(), (), 1).await;
1298            let _apply = txn.commit_at(&mut txns, ts).await.unwrap();
1299            log.record((d0, ts.to_string(), ts, 1));
1300        }
1301        // This automatically runs the apply, which catches up all the previous
1302        // txns at once.
1303        txns.expect_commit_at(10, d0, &[], &log).await;
1304
1305        log.assert_snapshot(d0, 10).await;
1306    }
1307
1308    struct StressWorker {
1309        idx: usize,
1310        data_ids: Vec<ShardId>,
1311        txns: TxnsHandle<String, (), u64, i64>,
1312        log: CommitLog,
1313        tidy: Tidy,
1314        ts: u64,
1315        step: usize,
1316        rng: SmallRng,
1317        reads: Vec<(
1318            oneshot::Sender<u64>,
1319            ShardId,
1320            u64,
1321            mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
1322        )>,
1323    }
1324
1325    impl StressWorker {
1326        pub async fn step(&mut self) {
1327            debug!(
1328                "stress {} step {} START ts={}",
1329                self.idx, self.step, self.ts
1330            );
1331            let data_id =
1332                self.data_ids[usize::cast_from(self.rng.next_u64()) % self.data_ids.len()];
1333            match self.rng.next_u64() % 6 {
1334                0 => self.write(data_id).await,
1335                // The register and forget impls intentionally don't switch on
1336                // whether it's already registered to stress idempotence.
1337                1 => self.register(data_id).await,
1338                2 => self.forget(data_id).await,
1339                3 => {
1340                    debug!("stress update {:.9} to {}", data_id.to_string(), self.ts);
1341                    let _ = self.txns.txns_cache.update_ge(&self.ts).await;
1342                }
1343                4 => self.start_read(data_id, true),
1344                5 => self.start_read(data_id, false),
1345                _ => unreachable!(),
1346            }
1347            debug!("stress {} step {} DONE ts={}", self.idx, self.step, self.ts);
1348            self.step += 1;
1349        }
1350
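        // The key written by this worker at this step. Encoding the worker index
        // and step makes concurrent writes distinguishable in the commit log.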
1351        fn key(&self) -> String {
1352            format!("w{}s{}", self.idx, self.step)
1353        }
1354
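        // Advances self.ts to the txns shard's current progress and returns
        // whether data_id is registered as of that ts.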
1355        async fn registered_at_progress_ts(&mut self, data_id: ShardId) -> bool {
1356            self.ts = *self.txns.txns_cache.update_ge(&self.ts).await;
1357            self.txns
1358                .txns_cache
1359                .registered_at_progress(&data_id, &self.ts)
1360        }
1361
1362        // Writes to the given data shard, either via txns if it's registered or
1363        // directly if it's not.
1364        async fn write(&mut self, data_id: ShardId) {
1365            // Make sure to keep the registered_at_progress_ts call _inside_ the
1366            // retry loop, because a data shard might switch between registered and
1367            // not registered as the loop advances through timestamps.
1368            self.retry_ts_err(&mut |w: &mut StressWorker| {
1369                Box::pin(async move {
1370                    if w.registered_at_progress_ts(data_id).await {
1371                        w.write_via_txns(data_id).await
1372                    } else {
1373                        w.write_direct(data_id).await
1374                    }
1375                })
1376            })
1377            .await
1378        }
1379
1380        async fn write_via_txns(&mut self, data_id: ShardId) -> Result<(), u64> {
1381            debug!(
1382                "stress write_via_txns {:.9} at {}",
1383                data_id.to_string(),
1384                self.ts
1385            );
1386            // HACK: Normally, we'd make sure that this particular handle had
1387            // registered the data shard before writing to it, but that would
1388            // consume a ts and isn't quite how we want `write_via_txns` to
1389            // work. Work around that by setting a write handle (with a schema
1390            // that we promise is correct) in the right place.
1391            if !self.txns.datas.data_write_for_commit.contains_key(&data_id) {
1392                let x = writer(&self.txns.datas.client, data_id).await;
1393                self.txns
1394                    .datas
1395                    .data_write_for_commit
1396                    .insert(data_id, DataWriteCommit(x));
1397            }
1398            let mut txn = self.txns.begin_test();
1399            txn.tidy(std::mem::take(&mut self.tidy));
1400            txn.write(&data_id, self.key(), (), 1).await;
1401            let apply = txn.commit_at(&mut self.txns, self.ts).await?;
1402            debug!(
1403                "log {:.9} {} at {}",
1404                data_id.to_string(),
1405                self.key(),
1406                self.ts
1407            );
1408            self.log.record_txn(self.ts, &txn);
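            // Apply the committed txn immediately about a third of the time;
            // otherwise leave it for some later apply. Any tidy work produced here
            // is carried forward into this worker's next txn.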
1409            if self.rng.next_u64() % 3 == 0 {
1410                self.tidy.merge(apply.apply(&mut self.txns).await);
1411            }
1412            Ok(())
1413        }
1414
1415        async fn write_direct(&mut self, data_id: ShardId) -> Result<(), u64> {
1416            debug!(
1417                "stress write_direct {:.9} at {}",
1418                data_id.to_string(),
1419                self.ts
1420            );
1421            // First write an empty txn to ensure that the shard isn't
1422            // registered at this ts by someone else.
1423            self.txns.begin().commit_at(&mut self.txns, self.ts).await?;
1424
1425            let mut write = writer(&self.txns.datas.client, data_id).await;
1426            let mut current = write.shared_upper().into_option().unwrap();
1427            loop {
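                // If the shard's upper has already advanced past our ts, surface
                // the new upper so the caller retries the write at a later ts.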
1428                if !(current <= self.ts) {
1429                    return Err(current);
1430                }
1431                let key = self.key();
1432                let updates = [((&key, &()), &self.ts, 1)];
1433                let res = crate::small_caa(
1434                    || format!("data {:.9} direct", data_id.to_string()),
1435                    &mut write,
1436                    &updates,
1437                    current,
1438                    self.ts + 1,
1439                )
1440                .await;
1441                match res {
1442                    Ok(()) => {
1443                        debug!("log {:.9} {} at {}", data_id.to_string(), key, self.ts);
1444                        self.log.record((data_id, key, self.ts, 1));
1445                        return Ok(());
1446                    }
1447                    Err(new_current) => current = new_current,
1448                }
1449            }
1450        }
1451
1452        async fn register(&mut self, data_id: ShardId) {
1453            self.retry_ts_err(&mut |w: &mut StressWorker| {
1454                debug!("stress register {:.9} at {}", data_id.to_string(), w.ts);
1455                Box::pin(async move {
1456                    let data_write = writer(&w.txns.datas.client, data_id).await;
1457                    let _ = w.txns.register(w.ts, [data_write]).await?;
1458                    Ok(())
1459                })
1460            })
1461            .await
1462        }
1463
1464        async fn forget(&mut self, data_id: ShardId) {
1465            self.retry_ts_err(&mut |w: &mut StressWorker| {
1466                debug!("stress forget {:.9} at {}", data_id.to_string(), w.ts);
1467                Box::pin(async move { w.txns.forget(w.ts, [data_id]).await.map(|_| ()) })
1468            })
1469            .await
1470        }
1471
1472        fn start_read(&mut self, data_id: ShardId, use_global_txn_cache: bool) {
1473            debug!(
1474                "stress start_read {:.9} at {}",
1475                data_id.to_string(),
1476                self.ts
1477            );
1478            let client = (*self.txns.datas.client).clone();
1479            let txns_id = self.txns.txns_id();
1480            let as_of = self.ts;
1481            debug!("start_read {:.9} as_of {}", data_id.to_string(), as_of);
1482            let (tx, mut rx) = oneshot::channel();
1483            let subscribe = mz_ore::task::spawn_blocking(
1484                || format!("{:.9}-{}", data_id.to_string(), as_of),
1485                move || {
1486                    let mut subscribe = DataSubscribe::new(
1487                        "test",
1488                        client,
1489                        txns_id,
1490                        data_id,
1491                        as_of,
1492                        Antichain::new(),
1493                        use_global_txn_cache,
1494                    );
1495                    let data_id = format!("{:.9}", data_id.to_string());
1496                    let _guard = info_span!("read_worker", %data_id, as_of).entered();
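                    // Step the subscription until the main task sends the final
                    // read-up-to ts over the oneshot (a closed channel means stop
                    // immediately), then run the dataflow up to that ts and return
                    // the captured output.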
1497                    loop {
1498                        subscribe.worker.step_or_park(None);
1499                        subscribe.capture_output();
1500                        let until = match rx.try_recv() {
1501                            Ok(ts) => ts,
1502                            Err(oneshot::error::TryRecvError::Empty) => {
1503                                continue;
1504                            }
1505                            Err(oneshot::error::TryRecvError::Closed) => 0,
1506                        };
1507                        while subscribe.progress() < until {
1508                            subscribe.worker.step_or_park(None);
1509                            subscribe.capture_output();
1510                        }
1511                        return subscribe.output().clone();
1512                    }
1513                },
1514            );
1515            self.reads.push((tx, data_id, as_of, subscribe));
1516        }
1517
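        // Repeatedly runs work_fn, bumping self.ts to the returned timestamp
        // whenever the work loses an upper race, until it succeeds.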
1518        async fn retry_ts_err<W>(&mut self, work_fn: &mut W)
1519        where
1520            W: for<'b> FnMut(&'b mut Self) -> BoxFuture<'b, Result<(), u64>>,
1521        {
1522            loop {
1523                match work_fn(self).await {
1524                    Ok(ret) => return ret,
1525                    Err(new_ts) => self.ts = new_ts,
1526                }
1527            }
1528        }
1529    }
1530
1531    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1532    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1533    async fn stress_correctness() {
1534        const NUM_DATA_SHARDS: usize = 2;
1535        const NUM_WORKERS: usize = 2;
1536        const NUM_STEPS_PER_WORKER: usize = 100;
1537        let seed = UNIX_EPOCH.elapsed().unwrap().hashed();
1538        eprintln!("using seed {}", seed);
1539
1540        let mut clients = PersistClientCache::new_no_metrics();
1541        // We disable pubsub below, so retune the listen retries (pubsub
1542        // fallback) to keep the test speedy.
1543        clients
1544            .cfg()
1545            .set_next_listen_batch_retryer(RetryParameters {
1546                fixed_sleep: Duration::ZERO,
1547                initial_backoff: Duration::from_millis(1),
1548                multiplier: 1,
1549                clamp: Duration::from_millis(1),
1550            });
1551        let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1552        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1553        let log = txns.new_log();
1554        let data_ids = (0..NUM_DATA_SHARDS)
1555            .map(|_| ShardId::new())
1556            .collect::<Vec<_>>();
1557        let data_writes = data_ids
1558            .iter()
1559            .map(|data_id| writer(&client, *data_id))
1560            .collect::<FuturesUnordered<_>>()
1561            .collect::<Vec<_>>()
1562            .await;
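        // Acquire (and hold) a read handle on each data shard; presumably this
        // holds back each shard's since so snapshots at old timestamps stay
        // readable for the duration of the test.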
1563        let _data_sinces = data_ids
1564            .iter()
1565            .map(|data_id| reader(&client, *data_id))
1566            .collect::<FuturesUnordered<_>>()
1567            .collect::<Vec<_>>()
1568            .await;
1569        let register_ts = 1;
1570        txns.register(register_ts, data_writes).await.unwrap();
1571
1572        let mut workers = Vec::new();
1573        for idx in 0..NUM_WORKERS {
1574            // Clear the state cache between each client to maximally disconnect
1575            // them from each other.
1576            clients.clear_state_cache();
1577            let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1578            let mut worker = StressWorker {
1579                idx,
1580                log: log.clone(),
1581                txns: TxnsHandle::expect_open_id(client.clone(), txns.txns_id()).await,
1582                data_ids: data_ids.clone(),
1583                tidy: Tidy::default(),
1584                ts: register_ts,
1585                step: 0,
1586                rng: SmallRng::seed_from_u64(seed.wrapping_add(u64::cast_from(idx))),
1587                reads: Vec::new(),
1588            };
1589            let worker = async move {
1590                while worker.step < NUM_STEPS_PER_WORKER {
1591                    worker.step().await;
1592                }
1593                (worker.ts, worker.reads)
1594            }
1595            .instrument(info_span!("stress_worker", idx));
1596            workers.push(mz_ore::task::spawn(|| format!("worker-{}", idx), worker));
1597        }
1598
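        // Wait for all workers to finish and collect their final timestamps and
        // any outstanding reader tasks.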
1599        let mut max_ts = 0;
1600        let mut reads = Vec::new();
1601        for worker in workers {
1602            let (t, mut r) = worker.await.unwrap();
1603            max_ts = std::cmp::max(max_ts, t);
1604            reads.append(&mut r);
1605        }
1606
1607        // Run all of the following in a timeout to make hangs easier to debug.
1608        tokio::time::timeout(Duration::from_secs(30), async {
1609            info!("finished with max_ts of {}", max_ts);
1610            txns.apply_le(&max_ts).await;
1611            for data_id in data_ids {
1612                info!("reading data shard {}", data_id);
1613                log.assert_snapshot(data_id, max_ts)
1614                    .instrument(info_span!("read_data", data_id = format!("{:.9}", data_id)))
1615                    .await;
1616            }
1617            info!("now waiting for reads {}", max_ts);
1618            for (tx, data_id, as_of, subscribe) in reads {
1619                let _ = tx.send(max_ts + 1);
1620                let output = subscribe.await.unwrap();
1621                log.assert_eq(data_id, as_of, max_ts + 1, output);
1622            }
1623        })
1624        .await
1625        .unwrap();
1626    }
1627
1628    #[mz_ore::test(tokio::test)]
1629    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1630    async fn advance_physical_uppers_past() {
1631        let client = PersistClient::new_for_tests().await;
1632        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1633        let log = txns.new_log();
1634        let d0 = txns.expect_register(1).await;
1635        let mut d0_write = writer(&client, d0).await;
1636        let d1 = txns.expect_register(2).await;
1637        let mut d1_write = writer(&client, d1).await;
1638
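        // Registering a data shard advances its physical upper to just past the
        // register ts.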
1639        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[2]);
1640        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1641
1642        // Normal `apply` (used by expect_commit_at) does not advance the
1643        // physical upper of data shards that were not involved in the txn (lazy
1644        // upper). d1 is not involved in this txn so stays where it is.
1645        txns.expect_commit_at(3, d0, &["0-2"], &log).await;
1646        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1647        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1648
1649        // d0 is not involved in this txn so stays where it is.
1650        txns.expect_commit_at(4, d1, &["1-3"], &log).await;
1651        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1652        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[5]);
1653
1654        log.assert_snapshot(d0, 4).await;
1655        log.assert_snapshot(d1, 4).await;
1656    }
1657
1658    #[mz_ore::test(tokio::test)]
1659    #[cfg_attr(miri, ignore)]
1660    #[allow(clippy::unnecessary_get_then_check)] // The suggested fix is less readable here.
1661    async fn schemas() {
1662        let client = PersistClient::new_for_tests().await;
1663        let mut txns0 = TxnsHandle::expect_open(client.clone()).await;
1664        let mut txns1 = TxnsHandle::expect_open_id(client.clone(), txns0.txns_id()).await;
1665        let log = txns0.new_log();
1666        let d0 = txns0.expect_register(1).await;
1667
1668        // The register call happened on txns0, which means it has a real schema
1669        // and can commit batches.
1670        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1671        let mut txn = txns0.begin_test();
1672        txn.write(&d0, "foo".into(), (), 1).await;
1673        let apply = txn.commit_at(&mut txns0, 2).await.unwrap();
1674        log.record_txn(2, &txn);
1675
1676        // We can use a handle without a register call to apply a committed txn.
1677        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
1678        let _tidy = apply.apply(&mut txns1).await;
1679
1680        // However, it cannot commit batches.
1681        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
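        // Spawn the commit attempt on a separate task so the expected panic is
        // captured as a task error instead of unwinding into the test.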
1682        let res = mz_ore::task::spawn(|| "test", async move {
1683            let mut txn = txns1.begin();
1684            txn.write(&d0, "bar".into(), (), 1).await;
1685            // This panics.
1686            let _ = txn.commit_at(&mut txns1, 3).await;
1687        });
1688        assert!(res.await.is_err());
1689
1690        // Forgetting the data shard removes its write handle, so we don't leave
1691        // the schema sitting around.
1692        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1693        txns0.forget(3, [d0]).await.unwrap();
1694        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1695
1696        // Forget is idempotent.
1697        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1698        txns0.forget(4, [d0]).await.unwrap();
1699        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1700
1701        // We can register it again and commit again.
1702        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1703        txns0
1704            .register(5, [writer(&client, d0).await])
1705            .await
1706            .unwrap();
1707        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1708        txns0.expect_commit_at(6, d0, &["baz"], &log).await;
1709
1710        log.assert_snapshot(d0, 6).await;
1711    }
1712}