mz_txn_wal/
txns.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! An interface for atomic multi-shard writes.

use std::collections::BTreeMap;
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use mz_dyncfg::{Config, ConfigSet, ConfigValHandle};
use mz_ore::collections::HashSet;
use mz_ore::instrument;
use mz_persist_client::batch::Batch;
use mz_persist_client::cfg::USE_CRITICAL_SINCE_TXN;
use mz_persist_client::critical::SinceHandle;
use mz_persist_client::write::WriteHandle;
use mz_persist_client::{Diagnostics, PersistClient, ShardId};
use mz_persist_types::schema::SchemaId;
use mz_persist_types::txn::{TxnsCodec, TxnsEntry};
use mz_persist_types::{Codec, Codec64, Opaque, StepForward};
use timely::order::TotalOrder;
use timely::progress::Timestamp;
use tracing::debug;

use crate::TxnsCodecDefault;
use crate::metrics::Metrics;
use crate::txn_cache::{TxnsCache, Unapplied};
use crate::txn_write::Txn;

/// An interface for atomic multi-shard writes.
///
/// This handle is acquired through [Self::open]. Any data shards must be
/// registered with [Self::register] before use. Transactions are then started
/// with [Self::begin].
///
/// # Implementation Details
///
/// The txns shard is structured as `(ShardId, Vec<u8>)` updates.
///
/// The core mechanism is that a txn commits a set of transmittable persist
/// _batch handles_ as `(ShardId, <opaque blob>)` pairs at a single timestamp.
/// This contractually both commits the txn and advances the logical upper of
/// _every_ data shard (not just the ones involved in the txn).
///
/// Example:
///
/// ```text
/// // A txn to only d0 at ts=1
/// (d0, <opaque blob A>, 1, 1)
/// // A txn to d0 (two blobs) and d1 (one blob) at ts=4
/// (d0, <opaque blob B>, 4, 1)
/// (d0, <opaque blob C>, 4, 1)
/// (d1, <opaque blob D>, 4, 1)
/// ```
///
/// However, the new commit is not readable until the txn apply has run, which
/// is expected to be done promptly by the committer, except in the event of a
/// crash. The apply, in ts order, moves the batch handles into the data shards
/// with a [compare_and_append_batch] (similar to how the multi-worker
/// persist_sink works).
///
/// [compare_and_append_batch]:
///     mz_persist_client::write::WriteHandle::compare_and_append_batch
///
/// Once apply has run, we "tidy" the txns shard by retracting the update that
/// added the batch. As a result, the contents of the txns shard at any given
/// timestamp are exactly the set of outstanding apply work (plus registrations,
/// see below).
///
/// Example (building on the above):
///
/// ```text
/// // Tidy for the first txn at ts=3
/// (d0, <opaque blob A>, 3, -1)
/// // Tidy for the second txn (the timestamps can be different for each
/// // retraction in a txn, but don't need to be)
/// (d0, <opaque blob B>, 5, -1)
/// (d0, <opaque blob C>, 6, -1)
/// (d1, <opaque blob D>, 6, -1)
/// ```
///
/// To make it easy to reason about exactly which data shards are registered in
/// the txn set at any given moment, a data shard is added to the set with a
/// `(ShardId, <empty>)` pair. The data may not be read before the timestamp of
/// the update (which starts at the time it was initialized, but may later be
/// forwarded).
///
/// Example (building on both of the above):
///
/// ```text
/// // d0 and d1 were both initialized before they were used above
/// (d0, <empty>, 0, 1)
/// (d1, <empty>, 2, 1)
/// ```
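///
/// # Usage sketch
///
/// A rough, hedged sketch of the typical call sequence, loosely based on the
/// tests in this module (types simplified to `String` keys and `u64`
/// timestamps; error handling and retry loops elided):
///
/// ```text
/// let mut txns = TxnsHandle::open(0, client, dyncfgs, metrics, txns_id).await;
/// // Register a data shard before writing to it through txns.
/// txns.register(1, [data_write]).await?;
/// // Commit an atomic write to one or more registered shards.
/// let mut txn = txns.begin();
/// txn.write(&data_id, "foo".into(), (), 1).await;
/// let apply = txn.commit_at(&mut txns, 2).await?;
/// // Unblock reads at the commit timestamp and collect the tidy work.
/// let _tidy = apply.apply(&mut txns).await;
/// ```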
#[derive(Debug)]
pub struct TxnsHandle<K: Codec, V: Codec, T, D, O = u64, C: TxnsCodec = TxnsCodecDefault> {
    pub(crate) metrics: Arc<Metrics>,
    pub(crate) txns_cache: TxnsCache<T, C>,
    pub(crate) txns_write: WriteHandle<C::Key, C::Val, T, i64>,
    pub(crate) txns_since: SinceHandle<C::Key, C::Val, T, i64, O>,
    pub(crate) datas: DataHandles<K, V, T, D>,
}

impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Monoid + Ord + Codec64 + Send + Sync,
    O: Opaque + Debug + Codec64,
    C: TxnsCodec,
{
    /// Returns a [TxnsHandle] committing to the given txn shard.
    ///
    /// `txns_id` identifies which shard will be used as the txns WAL. MZ will
    /// likely have one of these per environment, used by all processes and the
    /// same across restarts.
    ///
    /// This also does any (idempotent) initialization work: i.e. ensures that
    /// the txn shard is readable at `init_ts` by appending an empty batch, if
    /// necessary.
    pub async fn open(
        init_ts: T,
        client: PersistClient,
        dyncfgs: ConfigSet,
        metrics: Arc<Metrics>,
        txns_id: ShardId,
    ) -> Self {
        let (txns_key_schema, txns_val_schema) = C::schemas();
        let (mut txns_write, txns_read) = client
            .open(
                txns_id,
                Arc::new(txns_key_schema),
                Arc::new(txns_val_schema),
                Diagnostics {
                    shard_name: "txns".to_owned(),
                    handle_purpose: "commit txns".to_owned(),
                },
                USE_CRITICAL_SINCE_TXN.get(client.dyncfgs()),
            )
            .await
            .expect("txns schema shouldn't change");
        let txns_since = client
            .open_critical_since(
                txns_id,
                // TODO: We likely need to use a different critical reader id
                // for this if we want to be able to introspect it via SQL.
                PersistClient::CONTROLLER_CRITICAL_SINCE,
                Diagnostics {
                    shard_name: "txns".to_owned(),
                    handle_purpose: "commit txns".to_owned(),
                },
            )
            .await
            .expect("txns schema shouldn't change");
        let txns_cache = TxnsCache::init(init_ts, txns_read, &mut txns_write).await;
        TxnsHandle {
            metrics,
            txns_cache,
            txns_write,
            txns_since,
            datas: DataHandles {
                dyncfgs,
                client: Arc::new(client),
                data_write_for_apply: BTreeMap::new(),
                data_write_for_commit: BTreeMap::new(),
            },
        }
    }

    /// Returns a new, empty transaction that can involve the data shards
    /// registered with this handle.
    pub fn begin(&self) -> Txn<K, V, T, D> {
        // TODO: This is a method on the handle because we'll need WriteHandles
        // once we start spilling to s3.
        Txn::new()
    }

    /// Registers data shards for use with this txn set.
    ///
    /// A registration entry is written to the txn shard. If it is not possible
    /// to register the data at the requested time, an Err will be returned with
    /// the minimum time the data shards could be registered.
    ///
    /// This method is idempotent. Data shards currently registered at
    /// `register_ts` will not be registered a second time. Specifically, this
    /// method will return success when the most recent register ts `R` is
    /// less_equal to `register_ts` AND there is no forget ts between `R` and
    /// `register_ts`.
    ///
    /// As a side effect all txns <= register_ts are applied, including the
    /// registration itself.
    ///
    /// **WARNING!** While a data shard is registered to the txn set, writing to
    /// it directly (i.e. using a WriteHandle instead of the TxnsHandle, or
    /// registering it with another txn shard) will lead to incorrectness,
    /// undefined behavior, and (potentially sticky) panics.
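    ///
    /// For example, a hedged sketch of the idempotence rule (timestamps are
    /// illustrative and assume `register_ts` has not yet been closed off):
    ///
    /// ```text
    /// // Suppose the data shard was registered at ts 2 and nothing has happened since.
    /// register(register_ts=4, ..) // no-op success: register 2 <= 4, no forget in between
    /// // Suppose instead it was later forgotten at ts 5.
    /// register(register_ts=6, ..) // writes a new registration: the shard was forgotten at 5
    /// ```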
    #[instrument(level = "debug", fields(ts = ?register_ts))]
    pub async fn register(
        &mut self,
        register_ts: T,
        data_writes: impl IntoIterator<Item = WriteHandle<K, V, T, D>>,
    ) -> Result<Tidy, T> {
        let op = &Arc::clone(&self.metrics).register;
        op.run(async {
            let mut data_writes = data_writes.into_iter().collect::<Vec<_>>();

            // The txns system requires that all participating data shards have a
            // schema registered. Importantly, we must register a data shard's
            // schema _before_ we publish it to the txns shard.
            for data_write in &mut data_writes {
                data_write.ensure_schema_registered().await;
            }

            let updates = data_writes
                .iter()
                .map(|data_write| {
                    let data_id = data_write.shard_id();
                    let entry = TxnsEntry::Register(data_id, T::encode(&register_ts));
                    (data_id, C::encode(entry))
                })
                .collect::<Vec<_>>();
            let data_ids_debug = || {
                data_writes
                    .iter()
                    .map(|x| format!("{:.9}", x.shard_id().to_string()))
                    .collect::<Vec<_>>()
                    .join(" ")
            };

            let mut txns_upper = self
                .txns_write
                .shared_upper()
                .into_option()
                .expect("txns should not be closed");
            loop {
                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
                // Figure out which are still unregistered as of `txns_upper`. Below
                // we write conditionally on the upper being what we expect so that
                // we can re-run this if anything changes from underneath us.
                let updates = updates
                    .iter()
                    .flat_map(|(data_id, (key, val))| {
                        let registered =
                            self.txns_cache.registered_at_progress(data_id, &txns_upper);
                        (!registered).then_some(((key, val), &register_ts, 1))
                    })
                    .collect::<Vec<_>>();
                // If the txns_upper has passed register_ts, we can no longer write.
                if register_ts < txns_upper {
                    debug!(
                        "txns register {} at {:?} mismatch current={:?}",
                        data_ids_debug(),
                        register_ts,
                        txns_upper,
                    );
                    return Err(txns_upper);
                }

                let res = crate::small_caa(
                    || format!("txns register {}", data_ids_debug()),
                    &mut self.txns_write,
                    &updates,
                    txns_upper,
                    register_ts.step_forward(),
                )
                .await;
                match res {
                    Ok(()) => {
                        debug!(
                            "txns register {} at {:?} success",
                            data_ids_debug(),
                            register_ts
                        );
                        break;
                    }
                    Err(new_txns_upper) => {
                        self.metrics.register.retry_count.inc();
                        txns_upper = new_txns_upper;
                        continue;
                    }
                }
            }
            for data_write in data_writes {
                // If we already have a write handle for a newer version of a table, don't replace
                // it! Currently we only support adding columns to tables with a default value, so
                // the latest/newest schema will always be the most complete.
                //
                // TODO(alter_table): Revisit when we support dropping columns.
                match self.datas.data_write_for_commit.get(&data_write.shard_id()) {
                    None => {
                        self.datas
                            .data_write_for_commit
                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
                    }
                    Some(previous) => {
                        let new_schema_id = data_write.schema_id().expect("ensured above");

                        if let Some(prev_schema_id) = previous.schema_id()
                            && prev_schema_id > new_schema_id
                        {
                            mz_ore::soft_panic_or_log!(
                                "tried registering a WriteHandle with an older SchemaId; \
                                 prev_schema_id: {} new_schema_id: {} shard_id: {}",
                                prev_schema_id,
                                new_schema_id,
                                previous.shard_id(),
                            );
                            continue;
                        } else if previous.schema_id().is_none() {
                            mz_ore::soft_panic_or_log!(
                                "encountered data shard without a schema; shard_id: {}",
                                previous.shard_id(),
                            );
                        }

                        tracing::info!(
                            prev_schema_id = ?previous.schema_id(),
                            ?new_schema_id,
                            shard_id = %previous.shard_id(),
                            "replacing WriteHandle"
                        );
                        self.datas
                            .data_write_for_commit
                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
                    }
                }
            }
            let tidy = self.apply_le(&register_ts).await;

            Ok(tidy)
        })
        .await
    }

    /// Removes data shards from use with this txn set.
    ///
    /// The registration entry written to the txn shard is retracted. If it is
    /// not possible to forget the data shard at the requested time, an Err will
    /// be returned with the minimum time the data shards could be forgotten.
    ///
    /// This method is idempotent. Data shards currently forgotten at
    /// `forget_ts` will not be forgotten a second time. Specifically, this
    /// method will return success when the most recent forget ts (if any) `F`
    /// is less_equal to `forget_ts` AND there is no register ts between `F` and
    /// `forget_ts`.
    ///
    /// As a side effect all txns <= forget_ts are applied, including the
    /// forget itself.
    ///
    /// **WARNING!** While a data shard is registered to the txn set, writing to
    /// it directly (i.e. using a WriteHandle instead of the TxnsHandle, or
    /// registering it with another txn shard) will lead to incorrectness,
    /// undefined behavior, and (potentially sticky) panics.
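    ///
    /// For example, a hedged sketch mirroring the one on [Self::register]
    /// (timestamps are illustrative and assume `forget_ts` has not yet been
    /// closed off):
    ///
    /// ```text
    /// // Suppose the data shard was forgotten at ts 5 and nothing has happened since.
    /// forget(forget_ts=7, ..) // no-op success: forget 5 <= 7, no register in between
    /// // Suppose instead it was registered again at ts 8.
    /// forget(forget_ts=9, ..) // retracts the registration written at ts 8
    /// ```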
    #[instrument(level = "debug", fields(ts = ?forget_ts))]
    pub async fn forget(
        &mut self,
        forget_ts: T,
        data_ids: impl IntoIterator<Item = ShardId>,
    ) -> Result<Tidy, T> {
        let op = &Arc::clone(&self.metrics).forget;
        op.run(async {
            let data_ids = data_ids.into_iter().collect::<Vec<_>>();
            let mut txns_upper = self
                .txns_write
                .shared_upper()
                .into_option()
                .expect("txns should not be closed");
            loop {
                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();

                let data_ids_debug = || {
                    data_ids
                        .iter()
                        .map(|x| format!("{:.9}", x.to_string()))
                        .collect::<Vec<_>>()
                        .join(" ")
                };
                let updates = data_ids
                    .iter()
                    // Never registered or already forgotten. This could change in
                    // `[txns_upper, forget_ts]` (due to races), so close off that
                    // interval before returning; just don't write any updates.
                    .filter(|data_id| self.txns_cache.registered_at_progress(data_id, &txns_upper))
                    .map(|data_id| C::encode(TxnsEntry::Register(*data_id, T::encode(&forget_ts))))
                    .collect::<Vec<_>>();
                let updates = updates
                    .iter()
                    .map(|(key, val)| ((key, val), &forget_ts, -1))
                    .collect::<Vec<_>>();

                // If the txns_upper has passed forget_ts, we can no longer write.
                if forget_ts < txns_upper {
                    debug!(
                        "txns forget {} at {:?} mismatch current={:?}",
                        data_ids_debug(),
                        forget_ts,
                        txns_upper,
                    );
                    return Err(txns_upper);
                }

                // Ensure the latest writes for these shards have been applied, so we don't run
                // into any issues trying to apply them later.
                {
                    let data_ids: HashSet<_> = data_ids.iter().cloned().collect();
                    let data_latest_unapplied = self
                        .txns_cache
                        .unapplied_batches
                        .values()
                        .rev()
                        .find(|(x, _, _)| data_ids.contains(x));
                    if let Some((_, _, latest_write)) = data_latest_unapplied {
                        debug!(
                            "txns forget {} applying latest write {:?}",
                            data_ids_debug(),
                            latest_write,
                        );
                        let latest_write = latest_write.clone();
                        let _tidy = self.apply_le(&latest_write).await;
                    }
                }
                let res = crate::small_caa(
                    || format!("txns forget {}", data_ids_debug()),
                    &mut self.txns_write,
                    &updates,
                    txns_upper,
                    forget_ts.step_forward(),
                )
                .await;
                match res {
                    Ok(()) => {
                        debug!(
                            "txns forget {} at {:?} success",
                            data_ids_debug(),
                            forget_ts
                        );
                        break;
                    }
                    Err(new_txns_upper) => {
                        self.metrics.forget.retry_count.inc();
                        txns_upper = new_txns_upper;
                        continue;
                    }
                }
            }

            // Note: Ordering here matters. We want to generate the Tidy work _before_ removing
            // the handle, because the work will create a handle to the shard.
            let tidy = self.apply_le(&forget_ts).await;
            for data_id in &data_ids {
                self.datas.data_write_for_commit.remove(data_id);
            }

            Ok(tidy)
        })
        .await
    }

    /// Forgets, at the given timestamp, every data shard that is registered.
    /// Returns the ids of the forgotten shards. See [Self::forget].
    #[instrument(level = "debug", fields(ts = ?forget_ts))]
    pub async fn forget_all(&mut self, forget_ts: T) -> Result<(Vec<ShardId>, Tidy), T> {
        let op = &Arc::clone(&self.metrics).forget_all;
        op.run(async {
            let mut txns_upper = self
                .txns_write
                .shared_upper()
                .into_option()
                .expect("txns should not be closed");
            let registered = loop {
                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();

                let registered = self.txns_cache.all_registered_at_progress(&txns_upper);
                let data_ids_debug = || {
                    registered
                        .iter()
                        .map(|x| format!("{:.9}", x.to_string()))
                        .collect::<Vec<_>>()
                        .join(" ")
                };
                let updates = registered
                    .iter()
                    .map(|data_id| {
                        C::encode(crate::TxnsEntry::Register(*data_id, T::encode(&forget_ts)))
                    })
                    .collect::<Vec<_>>();
                let updates = updates
                    .iter()
                    .map(|(key, val)| ((key, val), &forget_ts, -1))
                    .collect::<Vec<_>>();

                // If the txns_upper has passed forget_ts, we can no longer write.
                if forget_ts < txns_upper {
                    debug!(
                        "txns forget_all {} at {:?} mismatch current={:?}",
                        data_ids_debug(),
                        forget_ts,
                        txns_upper,
                    );
                    return Err(txns_upper);
                }

                // Ensure the latest write has been applied, so we don't run into
                // any issues trying to apply it later.
                //
                // NB: It's _very_ important for correctness to get this from the
                // unapplied batches (which compact themselves naturally) and not
                // from the writes (which are artificially compacted based on when
                // we need reads).
                let data_latest_unapplied = self.txns_cache.unapplied_batches.values().last();
                if let Some((_, _, latest_write)) = data_latest_unapplied {
                    debug!(
                        "txns forget_all {} applying latest write {:?}",
                        data_ids_debug(),
                        latest_write,
                    );
                    let latest_write = latest_write.clone();
                    let _tidy = self.apply_le(&latest_write).await;
                }
                let res = crate::small_caa(
                    || format!("txns forget_all {}", data_ids_debug()),
                    &mut self.txns_write,
                    &updates,
                    txns_upper,
                    forget_ts.step_forward(),
                )
                .await;
                match res {
                    Ok(()) => {
                        debug!(
                            "txns forget_all {} at {:?} success",
                            data_ids_debug(),
                            forget_ts
                        );
                        break registered;
                    }
                    Err(new_txns_upper) => {
                        self.metrics.forget_all.retry_count.inc();
                        txns_upper = new_txns_upper;
                        continue;
                    }
                }
            };

            for data_id in registered.iter() {
                self.datas.data_write_for_commit.remove(data_id);
            }
            let tidy = self.apply_le(&forget_ts).await;

            Ok((registered, tidy))
        })
        .await
    }

    /// "Applies" all committed txns <= the given timestamp, ensuring that reads
    /// at that timestamp will not block.
    ///
    /// In the common case, the txn committer will have done this work and this
    /// method will be a no-op, but it is not guaranteed. In the event of a
    /// crash or race, this does whatever persist writes are necessary (and
    /// returns the resulting maintenance work), which could be significant.
    ///
    /// If the requested timestamp has not yet been written, this could block
    /// for an unbounded amount of time.
    ///
    /// This method is idempotent.
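    ///
    /// A hedged sketch of the effect, reusing the example updates from the
    /// type-level docs above (blob B was committed to d0 at ts=4):
    ///
    /// ```text
    /// // Before: (d0, <opaque blob B>, 4, 1) is still outstanding in the txns shard.
    /// let tidy = txns.apply_le(&4).await;
    /// // After: blob B has been compare_and_append'd into d0, and `tidy` holds the
    /// // retraction of (d0, <opaque blob B>) to be written back to the txns shard.
    /// ```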
    #[instrument(level = "debug", fields(ts = ?ts))]
    pub async fn apply_le(&mut self, ts: &T) -> Tidy {
        let op = &self.metrics.apply_le;
        op.run(async {
            debug!("apply_le {:?}", ts);
            let _ = self.txns_cache.update_gt(ts).await;
            self.txns_cache.update_gauges(&self.metrics);

            let mut unapplied_by_data = BTreeMap::<_, Vec<_>>::new();
            for (data_id, unapplied, unapplied_ts) in self.txns_cache.unapplied() {
                if ts < unapplied_ts {
                    break;
                }
                unapplied_by_data
                    .entry(*data_id)
                    .or_default()
                    .push((unapplied, unapplied_ts));
            }

            let retractions = FuturesUnordered::new();
            for (data_id, unapplied) in unapplied_by_data {
                let mut data_write = self.datas.take_write_for_apply(&data_id).await;
                retractions.push(async move {
                    let mut ret = Vec::new();
                    for (unapplied, unapplied_ts) in unapplied {
                        match unapplied {
                            Unapplied::RegisterForget => {
                                let () = crate::empty_caa(
                                    || {
                                        format!(
                                            "data {:.9} register/forget fill",
                                            data_id.to_string()
                                        )
                                    },
                                    &mut data_write,
                                    unapplied_ts.clone(),
                                )
                                .await;
                            }
                            Unapplied::Batch(batch_raws) => {
                                let batch_raws = batch_raws
                                    .into_iter()
                                    .map(|batch_raw| batch_raw.as_slice())
                                    .collect();
                                crate::apply_caa(
                                    &mut data_write,
                                    &batch_raws,
                                    unapplied_ts.clone(),
                                )
                                .await;
                                for batch_raw in batch_raws {
                                    // NB: Protos are not guaranteed to exactly roundtrip the
                                    // encoded bytes, so we intentionally use the raw batch so that
                                    // it definitely retracts.
                                    ret.push((
                                        batch_raw.to_vec(),
                                        (T::encode(unapplied_ts), data_id),
                                    ));
                                }
                            }
                        }
                    }
                    (data_write, ret)
                });
            }
            let retractions = retractions.collect::<Vec<_>>().await;
            let retractions = retractions
                .into_iter()
                .flat_map(|(data_write, retractions)| {
                    self.datas.put_write_for_apply(data_write);
                    retractions
                })
                .collect();

            // Remove all the applied registers.
            self.txns_cache.mark_register_applied(ts);

            debug!("apply_le {:?} success", ts);
            Tidy { retractions }
        })
        .await
    }

    /// Commits the tidy work at the given time.
    ///
    /// Mostly a helper to make it obvious that we can throw away the apply work
    /// (and not get into an infinite cycle of tidy->apply->tidy).
    #[cfg(test)]
    pub async fn tidy_at(&mut self, tidy_ts: T, tidy: Tidy) -> Result<(), T> {
        debug!("tidy at {:?}", tidy_ts);

        let mut txn = self.begin();
        txn.tidy(tidy);
        // We just constructed this txn, so it couldn't have committed any
        // batches, and thus there's nothing to apply. We're free to throw it
        // away.
        let apply = txn.commit_at(self, tidy_ts.clone()).await?;
        assert!(apply.is_empty());

        debug!("tidy at {:?} success", tidy_ts);
        Ok(())
    }

    /// Allows compaction of the txns shard as well as internal representations,
    /// losing the ability to answer queries about times less_than since_ts.
    ///
    /// In practice, this will likely only be called from the singleton
    /// controller process.
    pub async fn compact_to(&mut self, mut since_ts: T) {
        let op = &self.metrics.compact_to;
        op.run(async {
            tracing::debug!("compact_to {:?}", since_ts);
            let _ = self.txns_cache.update_gt(&since_ts).await;

            // NB: A critical invariant for how this all works is that we never
            // allow the since of the txns shard to pass any unapplied writes, so
            // reduce it as necessary.
            let min_unapplied_ts = self.txns_cache.min_unapplied_ts();
            if min_unapplied_ts < &since_ts {
                since_ts.clone_from(min_unapplied_ts);
            }
            crate::cads::<T, O, C>(&mut self.txns_since, since_ts).await;
        })
        .await
    }

    /// Returns the [ShardId] of the txns shard.
    pub fn txns_id(&self) -> ShardId {
        self.txns_write.shard_id()
    }

    /// Returns the [TxnsCache] used by this handle.
    pub fn read_cache(&self) -> &TxnsCache<T, C> {
        &self.txns_cache
    }
}

/// A token representing maintenance writes (in particular, retractions) to the
/// txns shard.
///
/// This can be written on its own with `TxnsHandle::tidy_at` or sidecar'd into
/// a normal txn with [Txn::tidy].
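///
/// A rough, hedged sketch of folding tidy work into a later transaction (names
/// illustrative):
///
/// ```text
/// let tidy = txns.apply_le(&ts).await;
/// let mut txn = txns.begin();
/// txn.tidy(tidy); // sidecar the retractions
/// txn.write(&data_id, key, (), 1).await;
/// txn.commit_at(&mut txns, commit_ts).await?;
/// ```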
#[derive(Debug, Default)]
pub struct Tidy {
    pub(crate) retractions: BTreeMap<Vec<u8>, ([u8; 8], ShardId)>,
}

impl Tidy {
    /// Merges the work represented by the other tidy into this one.
    pub fn merge(&mut self, other: Tidy) {
        self.retractions.extend(other.retractions)
    }
}

/// A helper to make a more targeted mutable borrow of self.
#[derive(Debug)]
pub(crate) struct DataHandles<K: Codec, V: Codec, T, D> {
    pub(crate) dyncfgs: ConfigSet,
    pub(crate) client: Arc<PersistClient>,
    /// See [DataWriteApply].
    ///
    /// This is lazily populated with the set of shards touched by `apply_le`.
    data_write_for_apply: BTreeMap<ShardId, DataWriteApply<K, V, T, D>>,
    /// See [DataWriteCommit].
    ///
    /// This contains the set of data shards registered, but not yet forgotten,
    /// via this particular handle.
    ///
    /// NB: In the common case, this and `_for_apply` will contain the same set
    /// of shards, but this is not required. A shard can be in either and not
    /// the other.
    data_write_for_commit: BTreeMap<ShardId, DataWriteCommit<K, V, T, D>>,
}

impl<K, V, T, D> DataHandles<K, V, T, D>
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
    D: Monoid + Ord + Codec64 + Send + Sync,
{
    async fn open_data_write_for_apply(&self, data_id: ShardId) -> DataWriteApply<K, V, T, D> {
        let diagnostics = Diagnostics::from_purpose("txn data");
        let schemas = self
            .client
            .latest_schema::<K, V, T, D>(data_id, diagnostics.clone())
            .await
            .expect("codecs have not changed");
        let (key_schema, val_schema) = match schemas {
            Some((_, key_schema, val_schema)) => (Arc::new(key_schema), Arc::new(val_schema)),
            // We will always have at least one schema registered by the time we reach this point,
            // because that is ensured at txn-registration time.
            None => unreachable!("data shard {} should have a schema", data_id),
        };
        let wrapped = self
            .client
            .open_writer(data_id, key_schema, val_schema, diagnostics)
            .await
            .expect("schema shouldn't change");
        DataWriteApply {
            apply_ensure_schema_match: APPLY_ENSURE_SCHEMA_MATCH.handle(&self.dyncfgs),
            client: Arc::clone(&self.client),
            wrapped,
        }
    }

    pub(crate) async fn take_write_for_apply(
        &mut self,
        data_id: &ShardId,
    ) -> DataWriteApply<K, V, T, D> {
        if let Some(data_write) = self.data_write_for_apply.remove(data_id) {
            return data_write;
        }
        self.open_data_write_for_apply(*data_id).await
    }

    pub(crate) fn put_write_for_apply(&mut self, data_write: DataWriteApply<K, V, T, D>) {
        self.data_write_for_apply
            .insert(data_write.shard_id(), data_write);
    }

    pub(crate) fn take_write_for_commit(
        &mut self,
        data_id: &ShardId,
    ) -> Option<DataWriteCommit<K, V, T, D>> {
        self.data_write_for_commit.remove(data_id)
    }

    pub(crate) fn put_write_for_commit(&mut self, data_write: DataWriteCommit<K, V, T, D>) {
        let prev = self
            .data_write_for_commit
            .insert(data_write.shard_id(), data_write);
        assert!(prev.is_none());
    }
}

/// A newtype wrapper around [WriteHandle] indicating that it has a real schema
/// registered by the user.
///
/// The txn-wal user declares which schema they'd like to use for committing
/// batches by passing it in (as part of the WriteHandle) in the call to
/// register. This must be used to encode any new batches written. The wrapper
/// keeps us from accidentally mixing up the WriteHandles that we internally
/// invent for applying the batches (which use a schema matching the one
/// declared in the batch).
#[derive(Debug)]
pub(crate) struct DataWriteCommit<K: Codec, V: Codec, T, D>(pub(crate) WriteHandle<K, V, T, D>);

impl<K: Codec, V: Codec, T, D> Deref for DataWriteCommit<K, V, T, D> {
    type Target = WriteHandle<K, V, T, D>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteCommit<K, V, T, D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

/// A newtype wrapper around [WriteHandle] indicating that it can alter the
/// schema it's using to match the one in the batches being appended.
///
/// When a batch is committed to txn-wal, it contains metadata about which
/// schemas were used to encode the data in it. Txn-wal then uses this info to
/// make sure that, in [TxnsHandle::apply_le], the `compare_and_append` call
/// happens on a handle with the same schema. This is accomplished by querying
/// the persist schema registry.
#[derive(Debug)]
pub(crate) struct DataWriteApply<K: Codec, V: Codec, T, D> {
    client: Arc<PersistClient>,
    apply_ensure_schema_match: ConfigValHandle<bool>,
    pub(crate) wrapped: WriteHandle<K, V, T, D>,
}

impl<K: Codec, V: Codec, T, D> Deref for DataWriteApply<K, V, T, D> {
    type Target = WriteHandle<K, V, T, D>;

    fn deref(&self) -> &Self::Target {
        &self.wrapped
    }
}

impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteApply<K, V, T, D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.wrapped
    }
}

pub(crate) const APPLY_ENSURE_SCHEMA_MATCH: Config<bool> = Config::new(
    "txn_wal_apply_ensure_schema_match",
    true,
    "CYA to skip updating write handle to batch schema in apply",
);

fn at_most_one_schema(
    schemas: impl Iterator<Item = SchemaId>,
) -> Result<Option<SchemaId>, (SchemaId, SchemaId)> {
    let mut schema = None;
    for s in schemas {
        match schema {
            None => schema = Some(s),
            Some(x) if s != x => return Err((s, x)),
            Some(_) => continue,
        }
    }
    Ok(schema)
}

impl<K, V, T, D> DataWriteApply<K, V, T, D>
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
    D: Monoid + Ord + Codec64 + Send + Sync,
{
    pub(crate) async fn maybe_replace_with_batch_schema(&mut self, batches: &[Batch<K, V, T, D>]) {
        // TODO: Remove this once everything is rolled out and we're sure it's
        // not going to cause any issues.
        if !self.apply_ensure_schema_match.get() {
            return;
        }
        let batch_schema = at_most_one_schema(batches.iter().flat_map(|x| x.schemas()));
        let batch_schema = batch_schema.unwrap_or_else(|_| {
            panic!(
                "txn-wal uses at most one schema to commit batches, got: {:?}",
                batches.iter().flat_map(|x| x.schemas()).collect::<Vec<_>>()
            )
        });
        let (batch_schema, handle_schema) = match (batch_schema, self.wrapped.schema_id()) {
            (Some(batch_schema), Some(handle_schema)) if batch_schema != handle_schema => {
                (batch_schema, handle_schema)
            }
            _ => return,
        };

        let data_id = self.shard_id();
        let diagnostics = Diagnostics::from_purpose("txn data");
        let (key_schema, val_schema) = self
            .client
            .get_schema::<K, V, T, D>(data_id, batch_schema, diagnostics.clone())
            .await
            .expect("codecs shouldn't change")
            .expect("id must have been registered to create this batch");
        let new_data_write = self
            .client
            .open_writer(
                self.shard_id(),
                Arc::new(key_schema),
                Arc::new(val_schema),
                diagnostics,
            )
            .await
            .expect("codecs shouldn't change");
        tracing::info!(
            "updated {} write handle from {} to {} to apply batches",
            data_id,
            handle_schema,
            batch_schema
        );
        assert_eq!(new_data_write.schema_id(), Some(batch_schema));
        self.wrapped = new_data_write;
    }
}

#[cfg(test)]
mod tests {
    use std::time::{Duration, UNIX_EPOCH};

    use differential_dataflow::Hashable;
    use futures::future::BoxFuture;
    use mz_ore::assert_none;
    use mz_ore::cast::CastFrom;
    use mz_ore::collections::CollectionExt;
    use mz_ore::metrics::MetricsRegistry;
    use mz_persist_client::PersistLocation;
    use mz_persist_client::cache::PersistClientCache;
    use mz_persist_client::cfg::RetryParameters;
    use rand::rngs::SmallRng;
    use rand::{RngCore, SeedableRng};
    use timely::progress::Antichain;
    use tokio::sync::oneshot;
    use tracing::{Instrument, info, info_span};

    use crate::operator::DataSubscribe;
    use crate::tests::{CommitLog, reader, write_directly, writer};

    use super::*;

    impl TxnsHandle<String, (), u64, i64, u64, TxnsCodecDefault> {
        pub(crate) async fn expect_open(client: PersistClient) -> Self {
            Self::expect_open_id(client, ShardId::new()).await
        }

        pub(crate) async fn expect_open_id(client: PersistClient, txns_id: ShardId) -> Self {
            let dyncfgs = crate::all_dyncfgs(client.dyncfgs().clone());
            Self::open(
                0,
                client,
                dyncfgs,
                Arc::new(Metrics::new(&MetricsRegistry::new())),
                txns_id,
            )
            .await
        }

        pub(crate) fn new_log(&self) -> CommitLog {
            CommitLog::new((*self.datas.client).clone(), self.txns_id())
        }

        pub(crate) async fn expect_register(&mut self, register_ts: u64) -> ShardId {
            self.expect_registers(register_ts, 1).await.into_element()
        }

        pub(crate) async fn expect_registers(
            &mut self,
            register_ts: u64,
            amount: usize,
        ) -> Vec<ShardId> {
            let data_ids: Vec<_> = (0..amount).map(|_| ShardId::new()).collect();
            let mut writers = Vec::new();
            for data_id in &data_ids {
                writers.push(writer(&self.datas.client, *data_id).await);
            }
            self.register(register_ts, writers).await.unwrap();
            data_ids
        }

        pub(crate) async fn expect_commit_at(
            &mut self,
            commit_ts: u64,
            data_id: ShardId,
            keys: &[&str],
            log: &CommitLog,
        ) -> Tidy {
            let mut txn = self.begin();
            for key in keys {
                txn.write(&data_id, (*key).into(), (), 1).await;
            }
            let tidy = txn
                .commit_at(self, commit_ts)
                .await
                .unwrap()
                .apply(self)
                .await;
            for key in keys {
                log.record((data_id, (*key).into(), commit_ts, 1));
            }
            tidy
        }
    }

1033    #[mz_ore::test(tokio::test)]
1034    #[cfg_attr(miri, ignore)] // too slow
1035    async fn register_at() {
1036        let client = PersistClient::new_for_tests().await;
1037        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1038        let log = txns.new_log();
1039        let d0 = txns.expect_register(2).await;
1040
1041        // Register a second time is a no-op (idempotent).
1042        txns.register(3, [writer(&client, d0).await]).await.unwrap();
1043
1044        // Cannot register a new data shard at an already closed off time. An
1045        // error is returned with the first time that a registration would
1046        // succeed.
1047        let d1 = ShardId::new();
1048        assert_eq!(
1049            txns.register(2, [writer(&client, d1).await])
1050                .await
1051                .unwrap_err(),
1052            4
1053        );
1054
1055        // Can still register after txns have been committed.
1056        txns.expect_commit_at(4, d0, &["foo"], &log).await;
1057        txns.register(5, [writer(&client, d1).await]).await.unwrap();
1058
1059        // We can also register some new and some already registered shards.
1060        let d2 = ShardId::new();
1061        txns.register(6, [writer(&client, d0).await, writer(&client, d2).await])
1062            .await
1063            .unwrap();
1064
1065        let () = log.assert_snapshot(d0, 6).await;
1066        let () = log.assert_snapshot(d1, 6).await;
1067    }
1068
1069    /// A sanity check that CommitLog catches an incorrect usage (proxy for a
1070    /// bug that looks like an incorrect usage).
1071    #[mz_ore::test(tokio::test)]
1072    #[cfg_attr(miri, ignore)] // too slow
1073    #[should_panic(expected = "left: [(\"foo\", 2, 1)]\n right: [(\"foo\", 2, 2)]")]
1074    async fn incorrect_usage_register_write_same_time() {
1075        let client = PersistClient::new_for_tests().await;
1076        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1077        let log = txns.new_log();
1078        let d0 = txns.expect_register(1).await;
1079        let mut d0_write = writer(&client, d0).await;
1080
1081        // Commit a write at ts 2...
1082        let mut txn = txns.begin_test();
1083        txn.write(&d0, "foo".into(), (), 1).await;
1084        let apply = txn.commit_at(&mut txns, 2).await.unwrap();
1085        log.record_txn(2, &txn);
1086        // ... and (incorrectly) also write to the shard normally at ts 2.
1087        let () = d0_write
1088            .compare_and_append(
1089                &[(("foo".to_owned(), ()), 2, 1)],
1090                Antichain::from_elem(2),
1091                Antichain::from_elem(3),
1092            )
1093            .await
1094            .unwrap()
1095            .unwrap();
1096        log.record((d0, "foo".into(), 2, 1));
1097        apply.apply(&mut txns).await;
1098
1099        // Verify that CommitLog catches this.
1100        log.assert_snapshot(d0, 2).await;
1101    }
1102
1103    #[mz_ore::test(tokio::test)]
1104    #[cfg_attr(miri, ignore)] // too slow
1105    async fn forget_at() {
1106        let client = PersistClient::new_for_tests().await;
1107        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1108        let log = txns.new_log();
1109
1110        // Can forget a data_shard that has not been registered.
1111        txns.forget(1, [ShardId::new()]).await.unwrap();
1112
1113        // Can forget multiple data_shards that have not been registered.
1114        txns.forget(2, (0..5).map(|_| ShardId::new()))
1115            .await
1116            .unwrap();
1117
1118        // Can forget a registered shard.
1119        let d0 = txns.expect_register(3).await;
1120        txns.forget(4, [d0]).await.unwrap();
1121
1122        // Can forget multiple registered shards.
1123        let ds = txns.expect_registers(5, 5).await;
1124        txns.forget(6, ds.clone()).await.unwrap();
1125
1126        // Forget is idempotent.
1127        txns.forget(7, [d0]).await.unwrap();
1128        txns.forget(8, ds.clone()).await.unwrap();
1129
1130        // Cannot forget at an already closed off time. An error is returned
1131        // with the first time that a registration would succeed.
1132        let d1 = txns.expect_register(9).await;
1133        assert_eq!(txns.forget(9, [d1]).await.unwrap_err(), 10);
1134
1135        // Write to txns and to d0 directly.
1136        let mut d0_write = writer(&client, d0).await;
1137        txns.expect_commit_at(10, d1, &["d1"], &log).await;
1138        let updates = [(("d0".to_owned(), ()), 10, 1)];
1139        d0_write
1140            .compare_and_append(&updates, d0_write.shared_upper(), Antichain::from_elem(11))
1141            .await
1142            .unwrap()
1143            .unwrap();
1144        log.record((d0, "d0".into(), 10, 1));
1145
1146        // Can register and forget an already registered and forgotten shard.
1147        txns.register(11, [writer(&client, d0).await])
1148            .await
1149            .unwrap();
1150        let mut forget_expected = vec![d0, d1];
1151        forget_expected.sort();
1152        assert_eq!(txns.forget_all(12).await.unwrap().0, forget_expected);
1153
1154        // Close shard to writes
1155        d0_write
1156            .compare_and_append_batch(&mut [], d0_write.shared_upper(), Antichain::new(), true)
1157            .await
1158            .unwrap()
1159            .unwrap();
1160
1161        let () = log.assert_snapshot(d0, 12).await;
1162        let () = log.assert_snapshot(d1, 12).await;
1163
1164        for di in ds {
1165            let mut di_write = writer(&client, di).await;
1166
1167            // Close shards to writes
1168            di_write
1169                .compare_and_append_batch(&mut [], di_write.shared_upper(), Antichain::new(), true)
1170                .await
1171                .unwrap()
1172                .unwrap();
1173
1174            let () = log.assert_snapshot(di, 8).await;
1175        }
1176    }
1177
1178    #[mz_ore::test(tokio::test)]
1179    #[cfg_attr(miri, ignore)] // too slow
1180    async fn register_forget() {
1181        async fn step_some_past(subs: &mut Vec<DataSubscribe>, ts: u64) {
1182            for (idx, sub) in subs.iter_mut().enumerate() {
1183                // Only step some of them to try to maximize edge cases.
1184                if usize::cast_from(ts) % (idx + 1) == 0 {
1185                    async {
1186                        info!("stepping sub {} past {}", idx, ts);
1187                        sub.step_past(ts).await;
1188                    }
1189                    .instrument(info_span!("sub", idx))
1190                    .await;
1191                }
1192            }
1193        }
1194
1195        let client = PersistClient::new_for_tests().await;
1196        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1197        let log = txns.new_log();
1198        let d0 = ShardId::new();
1199        let mut d0_write = writer(&client, d0).await;
1200        let mut subs = Vec::new();
1201
1202        // Loop for a while doing the following:
1203        // - Write directly to some time before register ts
1204        // - Register
1205        // - Write via txns
1206        // - Forget
1207        //
1208        // After each step, make sure that a subscription started at that time
1209        // works and that all subscriptions can be stepped through the expected
1210        // timestamp.
1211        let mut ts = 0;
1212        while ts < 32 {
1213            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1214            ts += 1;
1215            info!("{} direct", ts);
1216            txns.begin().commit_at(&mut txns, ts).await.unwrap();
1217            write_directly(ts, &mut d0_write, &[&format!("d{}", ts)], &log).await;
1218            step_some_past(&mut subs, ts).await;
1219            if ts % 11 == 0 {
1220                txns.compact_to(ts).await;
1221            }
1222
1223            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1224            ts += 1;
1225            info!("{} register", ts);
1226            txns.register(ts, [writer(&client, d0).await])
1227                .await
1228                .unwrap();
1229            step_some_past(&mut subs, ts).await;
1230            if ts % 11 == 0 {
1231                txns.compact_to(ts).await;
1232            }
1233
1234            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1235            ts += 1;
1236            info!("{} txns", ts);
1237            txns.expect_commit_at(ts, d0, &[&format!("t{}", ts)], &log)
1238                .await;
1239            step_some_past(&mut subs, ts).await;
1240            if ts % 11 == 0 {
1241                txns.compact_to(ts).await;
1242            }
1243
1244            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1245            ts += 1;
1246            info!("{} forget", ts);
1247            txns.forget(ts, [d0]).await.unwrap();
1248            step_some_past(&mut subs, ts).await;
1249            if ts % 11 == 0 {
1250                txns.compact_to(ts).await;
1251            }
1252        }
1253
1254        // Check all the subscribes.
1255        for mut sub in subs.into_iter() {
1256            sub.step_past(ts).await;
1257            log.assert_eq(d0, sub.as_of, sub.progress(), sub.output().clone());
1258        }
1259    }
1260
1261    // Regression test for a bug encountered during initial development:
1262    // - task 0 commits to a data shard at ts T
1263    // - before task 0 can unblock T for reads, task 1 tries to register it at
1264    //   time T (ditto T+X), this does a caa of empty space, advancing the upper
1265    //   of the data shard to T+1
1266    // - task 0 attempts to caa in the batch, but finds out the upper is T+1 and
1267    //   assumes someone else did the work
1268    // - result: the write is lost
1269    #[mz_ore::test(tokio::test)]
1270    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1271    async fn race_data_shard_register_and_commit() {
1272        let client = PersistClient::new_for_tests().await;
1273        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1274        let d0 = txns.expect_register(1).await;
1275
1276        let mut txn = txns.begin();
1277        txn.write(&d0, "foo".into(), (), 1).await;
1278        let commit_apply = txn.commit_at(&mut txns, 2).await.unwrap();
1279
1280        txns.register(3, [writer(&client, d0).await]).await.unwrap();
1281
1282        // Make sure that we can read empty at the register commit time even
1283        // before the txn commit apply.
1284        let actual = txns.txns_cache.expect_snapshot(&client, d0, 1).await;
1285        assert_eq!(actual, Vec::<String>::new());
1286
1287        commit_apply.apply(&mut txns).await;
1288        let actual = txns.txns_cache.expect_snapshot(&client, d0, 2).await;
1289        assert_eq!(actual, vec!["foo".to_owned()]);
1290    }
1291
1292    // A test that applies a batch of writes all at once.
1293    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1294    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1295    async fn apply_many_ts() {
1296        let client = PersistClient::new_for_tests().await;
1297        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1298        let log = txns.new_log();
1299        let d0 = txns.expect_register(1).await;
1300
1301        for ts in 2..10 {
1302            let mut txn = txns.begin();
1303            txn.write(&d0, ts.to_string(), (), 1).await;
1304            let _apply = txn.commit_at(&mut txns, ts).await.unwrap();
1305            log.record((d0, ts.to_string(), ts, 1));
1306        }
1307        // This automatically runs the apply, which catches up all the previous
1308        // txns at once.
1309        txns.expect_commit_at(10, d0, &[], &log).await;
1310
1311        log.assert_snapshot(d0, 10).await;
1312    }
1313
1314    struct StressWorker {
1315        idx: usize,
1316        data_ids: Vec<ShardId>,
1317        txns: TxnsHandle<String, (), u64, i64>,
1318        log: CommitLog,
1319        tidy: Tidy,
1320        ts: u64,
1321        step: usize,
1322        rng: SmallRng,
1323        reads: Vec<(
1324            oneshot::Sender<u64>,
1325            ShardId,
1326            u64,
1327            mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
1328        )>,
1329    }
1330
1331    impl StressWorker {
1332        pub async fn step(&mut self) {
1333            debug!(
1334                "stress {} step {} START ts={}",
1335                self.idx, self.step, self.ts
1336            );
1337            let data_id =
1338                self.data_ids[usize::cast_from(self.rng.next_u64()) % self.data_ids.len()];
1339            match self.rng.next_u64() % 6 {
1340                0 => self.write(data_id).await,
1341                // The register and forget impls intentionally don't switch on
1342                // whether it's already registered to stress idempotence.
1343                1 => self.register(data_id).await,
1344                2 => self.forget(data_id).await,
1345                3 => {
1346                    debug!("stress update {:.9} to {}", data_id.to_string(), self.ts);
1347                    let _ = self.txns.txns_cache.update_ge(&self.ts).await;
1348                }
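                // Reads get two arms so they are chosen roughly twice as often
                // as the other operations.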
1349                4 => self.start_read(data_id),
1350                5 => self.start_read(data_id),
1351                _ => unreachable!(),
1352            }
1353            debug!("stress {} step {} DONE ts={}", self.idx, self.step, self.ts);
1354            self.step += 1;
1355        }
1356
1357        fn key(&self) -> String {
1358            format!("w{}s{}", self.idx, self.step)
1359        }
1360
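        // Waits for the cached view of the txns shard to pass `self.ts` (bumping
        // `self.ts` to the resulting progress) and then reports whether `data_id`
        // is registered as of that progress.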
1361        async fn registered_at_progress_ts(&mut self, data_id: ShardId) -> bool {
1362            self.ts = *self.txns.txns_cache.update_ge(&self.ts).await;
1363            self.txns
1364                .txns_cache
1365                .registered_at_progress(&data_id, &self.ts)
1366        }
1367
1368        // Writes to the given data shard, either via txns if it's registered or
1369        // directly if it's not.
1370        async fn write(&mut self, data_id: ShardId) {
1371            // Make sure to keep the registered_at_ts call _inside_ the retry
1372            // loop, because a data shard might switch between registered or not
1373            // registered as the loop advances through timestamps.
1374            self.retry_ts_err(&mut |w: &mut StressWorker| {
1375                Box::pin(async move {
1376                    if w.registered_at_progress_ts(data_id).await {
1377                        w.write_via_txns(data_id).await
1378                    } else {
1379                        w.write_direct(data_id).await
1380                    }
1381                })
1382            })
1383            .await
1384        }
1385
1386        async fn write_via_txns(&mut self, data_id: ShardId) -> Result<(), u64> {
1387            debug!(
1388                "stress write_via_txns {:.9} at {}",
1389                data_id.to_string(),
1390                self.ts
1391            );
1392            // HACK: Normally, we'd make sure that this particular handle had
1393            // registered the data shard before writing to it, but that would
1394            // consume a ts and isn't quite how we want `write_via_txns` to
1395            // work. Work around that by setting a write handle (with a schema
1396            // that we promise is correct) in the right place.
1397            if !self.txns.datas.data_write_for_commit.contains_key(&data_id) {
1398                let x = writer(&self.txns.datas.client, data_id).await;
1399                self.txns
1400                    .datas
1401                    .data_write_for_commit
1402                    .insert(data_id, DataWriteCommit(x));
1403            }
1404            let mut txn = self.txns.begin_test();
1405            txn.tidy(std::mem::take(&mut self.tidy));
1406            txn.write(&data_id, self.key(), (), 1).await;
1407            let apply = txn.commit_at(&mut self.txns, self.ts).await?;
1408            debug!(
1409                "log {:.9} {} at {}",
1410                data_id.to_string(),
1411                self.key(),
1412                self.ts
1413            );
1414            self.log.record_txn(self.ts, &txn);
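            // Roughly a third of the time, eagerly apply the committed txn and
            // carry the resulting tidy work into this worker's next commit.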
1415            if self.rng.next_u64() % 3 == 0 {
1416                self.tidy.merge(apply.apply(&mut self.txns).await);
1417            }
1418            Ok(())
1419        }
1420
1421        async fn write_direct(&mut self, data_id: ShardId) -> Result<(), u64> {
1422            debug!(
1423                "stress write_direct {:.9} at {}",
1424                data_id.to_string(),
1425                self.ts
1426            );
1427            // First commit an empty txn to ensure that the shard isn't
1428            // registered at this ts by someone else.
1429            self.txns.begin().commit_at(&mut self.txns, self.ts).await?;
1430
1431            let mut write = writer(&self.txns.datas.client, data_id).await;
1432            let mut current = write.shared_upper().into_option().unwrap();
1433            loop {
1434                if !(current <= self.ts) {
1435                    return Err(current);
1436                }
1437                let key = self.key();
1438                let updates = [((&key, &()), &self.ts, 1)];
1439                let res = crate::small_caa(
1440                    || format!("data {:.9} direct", data_id.to_string()),
1441                    &mut write,
1442                    &updates,
1443                    current,
1444                    self.ts + 1,
1445                )
1446                .await;
1447                match res {
1448                    Ok(()) => {
1449                        debug!("log {:.9} {} at {}", data_id.to_string(), key, self.ts);
1450                        self.log.record((data_id, key, self.ts, 1));
1451                        return Ok(());
1452                    }
1453                    Err(new_current) => current = new_current,
1454                }
1455            }
1456        }
1457
1458        async fn register(&mut self, data_id: ShardId) {
1459            self.retry_ts_err(&mut |w: &mut StressWorker| {
1460                debug!("stress register {:.9} at {}", data_id.to_string(), w.ts);
1461                Box::pin(async move {
1462                    let data_write = writer(&w.txns.datas.client, data_id).await;
1463                    let _ = w.txns.register(w.ts, [data_write]).await?;
1464                    Ok(())
1465                })
1466            })
1467            .await
1468        }
1469
1470        async fn forget(&mut self, data_id: ShardId) {
1471            self.retry_ts_err(&mut |w: &mut StressWorker| {
1472                debug!("stress forget {:.9} at {}", data_id.to_string(), w.ts);
1473                Box::pin(async move { w.txns.forget(w.ts, [data_id]).await.map(|_| ()) })
1474            })
1475            .await
1476        }
1477
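        // Spawns a blocking task running a DataSubscribe on `data_id` as of the
        // current ts. The task steps the dataflow until the main test thread uses
        // the returned oneshot to say what frontier to read up to.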
1478        fn start_read(&mut self, data_id: ShardId) {
1479            debug!(
1480                "stress start_read {:.9} at {}",
1481                data_id.to_string(),
1482                self.ts
1483            );
1484            let client = (*self.txns.datas.client).clone();
1485            let txns_id = self.txns.txns_id();
1486            let as_of = self.ts;
1487            debug!("start_read {:.9} as_of {}", data_id.to_string(), as_of);
1488            let (tx, mut rx) = oneshot::channel();
1489            let subscribe = mz_ore::task::spawn_blocking(
1490                || format!("{:.9}-{}", data_id.to_string(), as_of),
1491                move || {
1492                    let mut subscribe = DataSubscribe::new(
1493                        "test",
1494                        client,
1495                        txns_id,
1496                        data_id,
1497                        as_of,
1498                        Antichain::new(),
1499                    );
1500                    let data_id = format!("{:.9}", data_id.to_string());
1501                    let _guard = info_span!("read_worker", %data_id, as_of).entered();
1502                    loop {
1503                        subscribe.worker.step_or_park(None);
1504                        subscribe.capture_output();
1505                        let until = match rx.try_recv() {
1506                            Ok(ts) => ts,
1507                            Err(oneshot::error::TryRecvError::Empty) => {
1508                                continue;
1509                            }
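                            // The sender was dropped: stop waiting for progress
                            // and return whatever output has been captured so far.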
1510                            Err(oneshot::error::TryRecvError::Closed) => 0,
1511                        };
1512                        while subscribe.progress() < until {
1513                            subscribe.worker.step_or_park(None);
1514                            subscribe.capture_output();
1515                        }
1516                        return subscribe.output().clone();
1517                    }
1518                },
1519            );
1520            self.reads.push((tx, data_id, as_of, subscribe));
1521        }
1522
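        // Runs `work_fn` in a loop, advancing `self.ts` to the returned timestamp
        // on each timestamp error, until it succeeds.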
1523        async fn retry_ts_err<W>(&mut self, work_fn: &mut W)
1524        where
1525            W: for<'b> FnMut(&'b mut Self) -> BoxFuture<'b, Result<(), u64>>,
1526        {
1527            loop {
1528                match work_fn(self).await {
1529                    Ok(ret) => return ret,
1530                    Err(new_ts) => self.ts = new_ts,
1531                }
1532            }
1533        }
1534    }
1535
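    // Runs several StressWorkers concurrently against a single txns shard and
    // then verifies every data shard (and every in-flight read) against the
    // shared commit log.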
1536    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1537    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1538    async fn stress_correctness() {
1539        const NUM_DATA_SHARDS: usize = 2;
1540        const NUM_WORKERS: usize = 2;
1541        const NUM_STEPS_PER_WORKER: usize = 100;
1542        let seed = UNIX_EPOCH.elapsed().unwrap().hashed();
1543        eprintln!("using seed {}", seed);
1544
1545        let mut clients = PersistClientCache::new_no_metrics();
1546        // We disable pubsub below, so retune the listen retries (pubsub
1547        // fallback) to keep the test speedy.
1548        clients
1549            .cfg()
1550            .set_next_listen_batch_retryer(RetryParameters {
1551                fixed_sleep: Duration::ZERO,
1552                initial_backoff: Duration::from_millis(1),
1553                multiplier: 1,
1554                clamp: Duration::from_millis(1),
1555            });
1556        let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1557        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1558        let log = txns.new_log();
1559        let data_ids = (0..NUM_DATA_SHARDS)
1560            .map(|_| ShardId::new())
1561            .collect::<Vec<_>>();
1562        let data_writes = data_ids
1563            .iter()
1564            .map(|data_id| writer(&client, *data_id))
1565            .collect::<FuturesUnordered<_>>()
1566            .collect::<Vec<_>>()
1567            .await;
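        // Also open (and hold) a reader on each data shard for the duration of
        // the test so the shards' sinces stay pinned where the final reads below
        // need them.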
1568        let _data_sinces = data_ids
1569            .iter()
1570            .map(|data_id| reader(&client, *data_id))
1571            .collect::<FuturesUnordered<_>>()
1572            .collect::<Vec<_>>()
1573            .await;
1574        let register_ts = 1;
1575        txns.register(register_ts, data_writes).await.unwrap();
1576
1577        let mut workers = Vec::new();
1578        for idx in 0..NUM_WORKERS {
1579            // Clear the state cache between each client to maximally disconnect
1580            // them from each other.
1581            clients.clear_state_cache();
1582            let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1583            let mut worker = StressWorker {
1584                idx,
1585                log: log.clone(),
1586                txns: TxnsHandle::expect_open_id(client.clone(), txns.txns_id()).await,
1587                data_ids: data_ids.clone(),
1588                tidy: Tidy::default(),
1589                ts: register_ts,
1590                step: 0,
1591                rng: SmallRng::seed_from_u64(seed.wrapping_add(u64::cast_from(idx))),
1592                reads: Vec::new(),
1593            };
1594            let worker = async move {
1595                while worker.step < NUM_STEPS_PER_WORKER {
1596                    worker.step().await;
1597                }
1598                (worker.ts, worker.reads)
1599            }
1600            .instrument(info_span!("stress_worker", idx));
1601            workers.push(mz_ore::task::spawn(|| format!("worker-{}", idx), worker));
1602        }
1603
1604        let mut max_ts = 0;
1605        let mut reads = Vec::new();
1606        for worker in workers {
1607            let (t, mut r) = worker.await.unwrap();
1608            max_ts = std::cmp::max(max_ts, t);
1609            reads.append(&mut r);
1610        }
1611
1612        // Run all of the following in a timeout to make hangs easier to debug.
1613        tokio::time::timeout(Duration::from_secs(30), async {
1614            info!("finished with max_ts of {}", max_ts);
1615            txns.apply_le(&max_ts).await;
1616            for data_id in data_ids {
1617                info!("reading data shard {}", data_id);
1618                log.assert_snapshot(data_id, max_ts)
1619                    .instrument(info_span!("read_data", data_id = format!("{:.9}", data_id)))
1620                    .await;
1621            }
1622            info!("now waiting for reads {}", max_ts);
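            // Release each subscribe task by telling it to read up to max_ts + 1,
            // then check its output against the commit log.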
1623            for (tx, data_id, as_of, subscribe) in reads {
1624                let _ = tx.send(max_ts + 1);
1625                let output = subscribe.await.unwrap();
1626                log.assert_eq(data_id, as_of, max_ts + 1, output);
1627            }
1628        })
1629        .await
1630        .unwrap();
1631    }
1632
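    // Normal commit+apply only advances the physical upper of the data shards
    // written in the txn; data shards not involved keep their (lazy) physical
    // upper where it was.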
1633    #[mz_ore::test(tokio::test)]
1634    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1635    async fn advance_physical_uppers_past() {
1636        let client = PersistClient::new_for_tests().await;
1637        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1638        let log = txns.new_log();
1639        let d0 = txns.expect_register(1).await;
1640        let mut d0_write = writer(&client, d0).await;
1641        let d1 = txns.expect_register(2).await;
1642        let mut d1_write = writer(&client, d1).await;
1643
1644        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[2]);
1645        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1646
1647        // Normal `apply` (used by expect_commit_at) does not advance the
1648        // physical upper of data shards that were not involved in the txn (lazy
1649        // upper). d1 is not involved in this txn so stays where it is.
1650        txns.expect_commit_at(3, d0, &["0-2"], &log).await;
1651        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1652        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1653
1654        // d0 is not involved in this txn so stays where it is.
1655        txns.expect_commit_at(4, d1, &["1-3"], &log).await;
1656        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1657        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[5]);
1658
1659        log.assert_snapshot(d0, 4).await;
1660        log.assert_snapshot(d1, 4).await;
1661    }
1662
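    // A TxnsHandle can only commit batches to a data shard for which it holds a
    // registered write handle (and thus a schema). Other handles can still apply
    // already-committed txns; forget drops the write handle so the schema isn't
    // retained, and a later register restores it.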
1663    #[mz_ore::test(tokio::test)]
1664    #[cfg_attr(miri, ignore)]
1665    #[allow(clippy::unnecessary_get_then_check)] // The suggested rewrite is less readable here.
1666    async fn schemas() {
1667        let client = PersistClient::new_for_tests().await;
1668        let mut txns0 = TxnsHandle::expect_open(client.clone()).await;
1669        let mut txns1 = TxnsHandle::expect_open_id(client.clone(), txns0.txns_id()).await;
1670        let log = txns0.new_log();
1671        let d0 = txns0.expect_register(1).await;
1672
1673        // The register call happened on txns0, which means it has a real schema
1674        // and can commit batches.
1675        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1676        let mut txn = txns0.begin_test();
1677        txn.write(&d0, "foo".into(), (), 1).await;
1678        let apply = txn.commit_at(&mut txns0, 2).await.unwrap();
1679        log.record_txn(2, &txn);
1680
1681        // We can use a handle without a register call to apply a committed txn.
1682        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
1683        let _tidy = apply.apply(&mut txns1).await;
1684
1685        // However, it cannot commit batches.
1686        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
1687        let res = mz_ore::task::spawn(|| "test", async move {
1688            let mut txn = txns1.begin();
1689            txn.write(&d0, "bar".into(), (), 1).await;
1690            // This panics: txns1 never registered d0, so it has no write handle (or schema) for it.
1691            let _ = txn.commit_at(&mut txns1, 3).await;
1692        });
1693        assert!(res.await.is_err());
1694
1695        // Forgetting the data shard removes it, so we don't leave the schema
1696        // sitting around.
1697        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1698        txns0.forget(3, [d0]).await.unwrap();
1699        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1700
1701        // Forget is idempotent.
1702        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1703        txns0.forget(4, [d0]).await.unwrap();
1704        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1705
1706        // We can register it again and commit again.
1707        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1708        txns0
1709            .register(5, [writer(&client, d0).await])
1710            .await
1711            .unwrap();
1712        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1713        txns0.expect_commit_at(6, d0, &["baz"], &log).await;
1714
1715        log.assert_snapshot(d0, 6).await;
1716    }
1717}