// mz_txn_wal/txns.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! An interface for atomic multi-shard writes.
12use std::collections::BTreeMap;
13use std::fmt::Debug;
14use std::ops::{Deref, DerefMut};
15use std::sync::Arc;
16
17use differential_dataflow::difference::Monoid;
18use differential_dataflow::lattice::Lattice;
19use futures::StreamExt;
20use futures::stream::FuturesUnordered;
21use mz_dyncfg::{Config, ConfigSet, ConfigValHandle};
22use mz_ore::collections::HashSet;
23use mz_ore::instrument;
24use mz_persist_client::batch::Batch;
25use mz_persist_client::cfg::USE_CRITICAL_SINCE_TXN;
26use mz_persist_client::critical::{Opaque, SinceHandle};
27use mz_persist_client::write::WriteHandle;
28use mz_persist_client::{Diagnostics, PersistClient, ShardId};
29use mz_persist_types::schema::SchemaId;
30use mz_persist_types::txn::{TxnsCodec, TxnsEntry};
31use mz_persist_types::{Codec, Codec64, StepForward};
32use timely::order::TotalOrder;
33use timely::progress::Timestamp;
34use tracing::debug;
35
36use crate::TxnsCodecDefault;
37use crate::metrics::Metrics;
38use crate::txn_cache::{TxnsCache, Unapplied};
39use crate::txn_write::Txn;
40
41/// An interface for atomic multi-shard writes.
42///
43/// This handle is acquired through [Self::open]. Any data shards must be
44/// registered with [Self::register] before use. Transactions are then started
45/// with [Self::begin].
46///
47/// # Implementation Details
48///
49/// The structure of the txns shard is `(ShardId, Vec<u8>)` updates.
50///
51/// The core mechanism is that a txn commits a set of transmittable persist
52/// _batch handles_ as `(ShardId, <opaque blob>)` pairs at a single timestamp.
53/// This contractually both commits the txn and advances the logical upper of
54/// _every_ data shard (not just the ones involved in the txn).
55///
56/// Example:
57///
58/// ```text
59/// // A txn to only d0 at ts=1
60/// (d0, <opaque blob A>, 1, 1)
61/// // A txn to d0 (two blobs) and d1 (one blob) at ts=4
62/// (d0, <opaque blob B>, 4, 1)
63/// (d0, <opaque blob C>, 4, 1)
64/// (d1, <opaque blob D>, 4, 1)
65/// ```
66///
67/// However, the new commit is not yet readable until the txn apply has run,
68/// which is expected to be promptly done by the committer, except in the event
69/// of a crash. This, in ts order, moves the batch handles into the data shards
70/// with a [compare_and_append_batch] (similar to how the multi-worker
71/// persist_sink works).
72///
73/// [compare_and_append_batch]:
74///     mz_persist_client::write::WriteHandle::compare_and_append_batch
75///
76/// Once apply is run, we "tidy" the txns shard by retracting the update adding
77/// the batch. As a result, the contents of the txns shard at any given
78/// timestamp is exactly the set of outstanding apply work (plus registrations,
79/// see below).
80///
81/// Example (building on the above):
82///
83/// ```text
84/// // Tidy for the first txn at ts=3
85/// (d0, <opaque blob A>, 3, -1)
86/// // Tidy for the second txn (the timestamps can be different for each
87/// // retraction in a txn, but don't need to be)
88/// (d0, <opaque blob B>, 5, -1)
89/// (d0, <opaque blob C>, 6, -1)
90/// (d1, <opaque blob D>, 6, -1)
91/// ```
92///
93/// To make it easy to reason about exactly which data shards are registered in
94/// the txn set at any given moment, the data shard is added to the set with a
95/// `(ShardId, <empty>)` pair. The data may not be read before the timestamp of
96/// the update (which starts at the time it was initialized, but it may later be
97/// forwarded).
98///
99/// Example (building on both of the above):
100///
101/// ```text
102/// // d0 and d1 were both initialized before they were used above
103/// (d0, <empty>, 0, 1)
104/// (d1, <empty>, 2, 1)
105/// ```
106#[derive(Debug)]
107pub struct TxnsHandle<K: Codec, V: Codec, T, D, C: TxnsCodec = TxnsCodecDefault> {
108    pub(crate) metrics: Arc<Metrics>,
109    pub(crate) txns_cache: TxnsCache<T, C>,
110    pub(crate) txns_write: WriteHandle<C::Key, C::Val, T, i64>,
111    pub(crate) txns_since: SinceHandle<C::Key, C::Val, T, i64>,
112    pub(crate) datas: DataHandles<K, V, T, D>,
113}
114
115impl<K, V, T, D, C> TxnsHandle<K, V, T, D, C>
116where
117    K: Debug + Codec,
118    V: Debug + Codec,
119    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
120    D: Debug + Monoid + Ord + Codec64 + Send + Sync,
121    C: TxnsCodec,
122{
123    /// Returns a [TxnsHandle] committing to the given txn shard.
124    ///
125    /// `txns_id` identifies which shard will be used as the txns WAL. MZ will
126    /// likely have one of these per env, used by all processes and the same
127    /// across restarts.
128    ///
129    /// This also does any (idempotent) initialization work: i.e. ensures that
130    /// the txn shard is readable at `init_ts` by appending an empty batch, if
131    /// necessary.
132    pub async fn open(
133        init_ts: T,
134        client: PersistClient,
135        dyncfgs: ConfigSet,
136        metrics: Arc<Metrics>,
137        txns_id: ShardId,
138        opaque: Opaque,
139    ) -> Self {
140        let (txns_key_schema, txns_val_schema) = C::schemas();
141        let (mut txns_write, txns_read) = client
142            .open(
143                txns_id,
144                Arc::new(txns_key_schema),
145                Arc::new(txns_val_schema),
146                Diagnostics {
147                    shard_name: "txns".to_owned(),
148                    handle_purpose: "commit txns".to_owned(),
149                },
150                USE_CRITICAL_SINCE_TXN.get(client.dyncfgs()),
151            )
152            .await
153            .expect("txns schema shouldn't change");
154        let txns_since = client
155            .open_critical_since(
156                txns_id,
157                // TODO: We likely need to use a different critical reader id
158                // for this if we want to be able to introspect it via SQL.
159                PersistClient::CONTROLLER_CRITICAL_SINCE,
160                opaque,
161                Diagnostics {
162                    shard_name: "txns".to_owned(),
163                    handle_purpose: "commit txns".to_owned(),
164                },
165            )
166            .await
167            .expect("txns schema shouldn't change");
168        let txns_cache = TxnsCache::init(init_ts, txns_read, &mut txns_write).await;
169        TxnsHandle {
170            metrics,
171            txns_cache,
172            txns_write,
173            txns_since,
174            datas: DataHandles {
175                dyncfgs,
176                client: Arc::new(client),
177                data_write_for_apply: BTreeMap::new(),
178                data_write_for_commit: BTreeMap::new(),
179            },
180        }
181    }
182
183    /// Returns a new, empty transaction that can involve the data shards
184    /// registered with this handle.
185    pub fn begin(&self) -> Txn<K, V, T, D> {
186        // TODO: This is a method on the handle because we'll need WriteHandles
187        // once we start spilling to s3.
188        Txn::new()
189    }
190
191    /// Registers data shards for use with this txn set.
192    ///
193    /// A registration entry is written to the txn shard. If it is not possible
194    /// to register the data at the requested time, an Err will be returned with
195    /// the minimum time the data shards could be registered.
196    ///
197    /// This method is idempotent. Data shards currently registered at
198    /// `register_ts` will not be registered a second time. Specifically, this
199    /// method will return success when the most recent register ts `R` is
200    /// less_equal to `register_ts` AND there is no forget ts between `R` and
201    /// `register_ts`.
202    ///
203    /// As a side effect all txns <= register_ts are applied, including the
204    /// registration itself.
205    ///
206    /// **WARNING!** While a data shard is registered to the txn set, writing to
207    /// it directly (i.e. using a WriteHandle instead of the TxnHandle,
208    /// registering it with another txn shard) will lead to incorrectness,
209    /// undefined behavior, and (potentially sticky) panics.
210    #[instrument(level = "debug", fields(ts = ?register_ts))]
211    pub async fn register(
212        &mut self,
213        register_ts: T,
214        data_writes: impl IntoIterator<Item = WriteHandle<K, V, T, D>>,
215    ) -> Result<Tidy, T> {
216        let op = &Arc::clone(&self.metrics).register;
217        op.run(async {
218            let mut data_writes = data_writes.into_iter().collect::<Vec<_>>();
219
220            // The txns system requires that all participating data shards have a
221            // schema registered. Importantly, we must register a data shard's
222            // schema _before_ we publish it to the txns shard.
223            for data_write in &mut data_writes {
224                // Note that if this fails we'll bail out farther down in any case,
225                // so we might as well fail fast.
226                data_write
227                    .try_register_schema()
228                    .await
229                    .expect("schema should be registered");
230            }
231
232            let updates = data_writes
233                .iter()
234                .map(|data_write| {
235                    let data_id = data_write.shard_id();
236                    let entry = TxnsEntry::Register(data_id, T::encode(&register_ts));
237                    (data_id, C::encode(entry))
238                })
239                .collect::<Vec<_>>();
240            let data_ids_debug = || {
241                data_writes
242                    .iter()
243                    .map(|x| format!("{:.9}", x.shard_id().to_string()))
244                    .collect::<Vec<_>>()
245                    .join(" ")
246            };
247
248            let mut txns_upper = self
249                .txns_write
250                .shared_upper()
251                .into_option()
252                .expect("txns should not be closed");
253            loop {
254                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
255                // Figure out which are still unregistered as of `txns_upper`. Below
256                // we write conditionally on the upper being what we expect so than
257                // we can re-run this if anything changes from underneath us.
258                let updates = updates
259                    .iter()
260                    .flat_map(|(data_id, (key, val))| {
261                        let registered =
262                            self.txns_cache.registered_at_progress(data_id, &txns_upper);
263                        (!registered).then_some(((key, val), &register_ts, 1))
264                    })
265                    .collect::<Vec<_>>();
266                // If the txns_upper has passed register_ts, we can no longer write.
267                if register_ts < txns_upper {
268                    debug!(
269                        "txns register {} at {:?} mismatch current={:?}",
270                        data_ids_debug(),
271                        register_ts,
272                        txns_upper,
273                    );
274                    return Err(txns_upper);
275                }
276
277                let res = crate::small_caa(
278                    || format!("txns register {}", data_ids_debug()),
279                    &mut self.txns_write,
280                    &updates,
281                    txns_upper,
282                    register_ts.step_forward(),
283                )
284                .await;
285                match res {
286                    Ok(()) => {
287                        debug!(
288                            "txns register {} at {:?} success",
289                            data_ids_debug(),
290                            register_ts
291                        );
292                        break;
293                    }
294                    Err(new_txns_upper) => {
295                        self.metrics.register.retry_count.inc();
296                        txns_upper = new_txns_upper;
297                        continue;
298                    }
299                }
300            }
301            for data_write in data_writes {
302                // If we already have a write handle for a newer version of a table, don't replace
303                // it! Currently we only support adding columns to tables with a default value, so
304                // the latest/newest schema will always be the most complete.
305                //
306                // TODO(alter_table): Revisit when we support dropping columns.
307                match self.datas.data_write_for_commit.get(&data_write.shard_id()) {
308                    None => {
309                        self.datas
310                            .data_write_for_commit
311                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
312                    }
313                    Some(previous) => {
314                        let new_schema_id = data_write.schema_id().expect("ensured above");
315
316                        if let Some(prev_schema_id) = previous.schema_id()
317                            && prev_schema_id > new_schema_id
318                        {
319                            mz_ore::soft_panic_or_log!(
320                                "tried registering a WriteHandle with an older SchemaId; \
321                                 prev_schema_id: {} new_schema_id: {} shard_id: {}",
322                                prev_schema_id,
323                                new_schema_id,
324                                previous.shard_id(),
325                            );
326                            continue;
327                        } else if previous.schema_id().is_none() {
328                            mz_ore::soft_panic_or_log!(
329                                "encountered data shard without a schema; shard_id: {}",
330                                previous.shard_id(),
331                            );
332                        }
333
334                        tracing::info!(
335                            prev_schema_id = ?previous.schema_id(),
336                            ?new_schema_id,
337                            shard_id = %previous.shard_id(),
338                            "replacing WriteHandle"
339                        );
340                        self.datas
341                            .data_write_for_commit
342                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
343                    }
344                }
345            }
346            let tidy = self.apply_le(&register_ts).await;
347
348            Ok(tidy)
349        })
350        .await
351    }
352
353    /// Removes data shards from use with this txn set.
354    ///
355    /// The registration entry written to the txn shard is retracted. If it is
356    /// not possible to forget the data shard at the requested time, an Err will
357    /// be returned with the minimum time the data shards could be forgotten.
358    ///
359    /// This method is idempotent. Data shards currently forgotten at
360    /// `forget_ts` will not be forgotten a second time. Specifically, this
361    /// method will return success when the most recent forget ts (if any) `F`
362    /// is less_equal to `forget_ts` AND there is no register ts between `F` and
363    /// `forget_ts`.
364    ///
365    /// As a side effect all txns <= forget_ts are applied, including the
366    /// forget itself.
367    ///
368    /// **WARNING!** While a data shard is registered to the txn set, writing to
369    /// it directly (i.e. using a WriteHandle instead of the TxnHandle,
370    /// registering it with another txn shard) will lead to incorrectness,
371    /// undefined behavior, and (potentially sticky) panics.
372    #[instrument(level = "debug", fields(ts = ?forget_ts))]
373    pub async fn forget(
374        &mut self,
375        forget_ts: T,
376        data_ids: impl IntoIterator<Item = ShardId>,
377    ) -> Result<Tidy, T> {
378        let op = &Arc::clone(&self.metrics).forget;
379        op.run(async {
380            let data_ids = data_ids.into_iter().collect::<Vec<_>>();
381            let mut txns_upper = self
382                .txns_write
383                .shared_upper()
384                .into_option()
385                .expect("txns should not be closed");
386            loop {
387                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
388
389                let data_ids_debug = || {
390                    data_ids
391                        .iter()
392                        .map(|x| format!("{:.9}", x.to_string()))
393                        .collect::<Vec<_>>()
394                        .join(" ")
395                };
396                let updates = data_ids
397                    .iter()
398                    // Never registered or already forgotten. This could change in
399                    // `[txns_upper, forget_ts]` (due to races) so close off that
400                    // interval before returning, just don't write any updates.
401                    .filter(|data_id| self.txns_cache.registered_at_progress(data_id, &txns_upper))
402                    .map(|data_id| C::encode(TxnsEntry::Register(*data_id, T::encode(&forget_ts))))
403                    .collect::<Vec<_>>();
404                let updates = updates
405                    .iter()
406                    .map(|(key, val)| ((key, val), &forget_ts, -1))
407                    .collect::<Vec<_>>();
408
409                // If the txns_upper has passed forget_ts, we can no longer write.
410                if forget_ts < txns_upper {
411                    debug!(
412                        "txns forget {} at {:?} mismatch current={:?}",
413                        data_ids_debug(),
414                        forget_ts,
415                        txns_upper,
416                    );
417                    return Err(txns_upper);
418                }
419
420                // Ensure the latest writes for each shard has been applied, so we don't run into
421                // any issues trying to apply it later.
422                {
423                    let data_ids: HashSet<_> = data_ids.iter().cloned().collect();
424                    let data_latest_unapplied = self
425                        .txns_cache
426                        .unapplied_batches
427                        .values()
428                        .rev()
429                        .find(|(x, _, _)| data_ids.contains(x));
430                    if let Some((_, _, latest_write)) = data_latest_unapplied {
431                        debug!(
432                            "txns forget {} applying latest write {:?}",
433                            data_ids_debug(),
434                            latest_write,
435                        );
436                        let latest_write = latest_write.clone();
437                        let _tidy = self.apply_le(&latest_write).await;
438                    }
439                }
440                let res = crate::small_caa(
441                    || format!("txns forget {}", data_ids_debug()),
442                    &mut self.txns_write,
443                    &updates,
444                    txns_upper,
445                    forget_ts.step_forward(),
446                )
447                .await;
448                match res {
449                    Ok(()) => {
450                        debug!(
451                            "txns forget {} at {:?} success",
452                            data_ids_debug(),
453                            forget_ts
454                        );
455                        break;
456                    }
457                    Err(new_txns_upper) => {
458                        self.metrics.forget.retry_count.inc();
459                        txns_upper = new_txns_upper;
460                        continue;
461                    }
462                }
463            }
464
465            // Note: Ordering here matters, we want to generate the Tidy work _before_ removing the
466            // handle because the work will create a handle to the shard.
467            let tidy = self.apply_le(&forget_ts).await;
468            for data_id in &data_ids {
469                self.datas.data_write_for_commit.remove(data_id);
470            }
471
472            Ok(tidy)
473        })
474        .await
475    }
476
477    /// Forgets, at the given timestamp, every data shard that is registered.
478    /// Returns the ids of the forgotten shards. See [Self::forget].
479    #[instrument(level = "debug", fields(ts = ?forget_ts))]
480    pub async fn forget_all(&mut self, forget_ts: T) -> Result<(Vec<ShardId>, Tidy), T> {
481        let op = &Arc::clone(&self.metrics).forget_all;
482        op.run(async {
483            let mut txns_upper = self
484                .txns_write
485                .shared_upper()
486                .into_option()
487                .expect("txns should not be closed");
488            let registered = loop {
489                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
490
491                let registered = self.txns_cache.all_registered_at_progress(&txns_upper);
492                let data_ids_debug = || {
493                    registered
494                        .iter()
495                        .map(|x| format!("{:.9}", x.to_string()))
496                        .collect::<Vec<_>>()
497                        .join(" ")
498                };
499                let updates = registered
500                    .iter()
501                    .map(|data_id| {
502                        C::encode(crate::TxnsEntry::Register(*data_id, T::encode(&forget_ts)))
503                    })
504                    .collect::<Vec<_>>();
505                let updates = updates
506                    .iter()
507                    .map(|(key, val)| ((key, val), &forget_ts, -1))
508                    .collect::<Vec<_>>();
509
510                // If the txns_upper has passed forget_ts, we can no longer write.
511                if forget_ts < txns_upper {
512                    debug!(
513                        "txns forget_all {} at {:?} mismatch current={:?}",
514                        data_ids_debug(),
515                        forget_ts,
516                        txns_upper,
517                    );
518                    return Err(txns_upper);
519                }
520
521                // Ensure the latest write has been applied, so we don't run into
522                // any issues trying to apply it later.
523                //
524                // NB: It's _very_ important for correctness to get this from the
525                // unapplied batches (which compact themselves naturally) and not
526                // from the writes (which are artificially compacted based on when
527                // we need reads for).
528                let data_latest_unapplied = self.txns_cache.unapplied_batches.values().last();
529                if let Some((_, _, latest_write)) = data_latest_unapplied {
530                    debug!(
531                        "txns forget_all {} applying latest write {:?}",
532                        data_ids_debug(),
533                        latest_write,
534                    );
535                    let latest_write = latest_write.clone();
536                    let _tidy = self.apply_le(&latest_write).await;
537                }
538                let res = crate::small_caa(
539                    || format!("txns forget_all {}", data_ids_debug()),
540                    &mut self.txns_write,
541                    &updates,
542                    txns_upper,
543                    forget_ts.step_forward(),
544                )
545                .await;
546                match res {
547                    Ok(()) => {
548                        debug!(
549                            "txns forget_all {} at {:?} success",
550                            data_ids_debug(),
551                            forget_ts
552                        );
553                        break registered;
554                    }
555                    Err(new_txns_upper) => {
556                        self.metrics.forget_all.retry_count.inc();
557                        txns_upper = new_txns_upper;
558                        continue;
559                    }
560                }
561            };
562
563            for data_id in registered.iter() {
564                self.datas.data_write_for_commit.remove(data_id);
565            }
566            let tidy = self.apply_le(&forget_ts).await;
567
568            Ok((registered, tidy))
569        })
570        .await
571    }
572
573    /// "Applies" all committed txns <= the given timestamp, ensuring that reads
574    /// at that timestamp will not block.
575    ///
576    /// In the common case, the txn committer will have done this work and this
577    /// method will be a no-op, but it is not guaranteed. In the event of a
578    /// crash or race, this does whatever persist writes are necessary (and
579    /// returns the resulting maintenance work), which could be significant.
580    ///
581    /// If the requested timestamp has not yet been written, this could block
582    /// for an unbounded amount of time.
583    ///
584    /// This method is idempotent.
585    #[instrument(level = "debug", fields(ts = ?ts))]
586    pub async fn apply_le(&mut self, ts: &T) -> Tidy {
587        let op = &self.metrics.apply_le;
588        op.run(async {
589            debug!("apply_le {:?}", ts);
590            let _ = self.txns_cache.update_gt(ts).await;
591            self.txns_cache.update_gauges(&self.metrics);
592
593            let mut unapplied_by_data = BTreeMap::<_, Vec<_>>::new();
594            for (data_id, unapplied, unapplied_ts) in self.txns_cache.unapplied() {
595                if ts < unapplied_ts {
596                    break;
597                }
598                unapplied_by_data
599                    .entry(*data_id)
600                    .or_default()
601                    .push((unapplied, unapplied_ts));
602            }
603
604            let retractions = FuturesUnordered::new();
605            for (data_id, unapplied) in unapplied_by_data {
606                let mut data_write = self.datas.take_write_for_apply(&data_id).await;
607                retractions.push(async move {
608                    let mut ret = Vec::new();
609                    for (unapplied, unapplied_ts) in unapplied {
610                        match unapplied {
611                            Unapplied::RegisterForget => {
612                                let () = crate::empty_caa(
613                                    || {
614                                        format!(
615                                            "data {:.9} register/forget fill",
616                                            data_id.to_string()
617                                        )
618                                    },
619                                    &mut data_write,
620                                    unapplied_ts.clone(),
621                                )
622                                .await;
623                            }
624                            Unapplied::Batch(batch_raws) => {
625                                let batch_raws = batch_raws
626                                    .into_iter()
627                                    .map(|batch_raw| batch_raw.as_slice())
628                                    .collect();
629                                crate::apply_caa(
630                                    &mut data_write,
631                                    &batch_raws,
632                                    unapplied_ts.clone(),
633                                )
634                                .await;
635                                for batch_raw in batch_raws {
636                                    // NB: Protos are not guaranteed to exactly roundtrip the
637                                    // encoded bytes, so we intentionally use the raw batch so that
638                                    // it definitely retracts.
639                                    ret.push((
640                                        batch_raw.to_vec(),
641                                        (T::encode(unapplied_ts), data_id),
642                                    ));
643                                }
644                            }
645                        }
646                    }
647                    (data_write, ret)
648                });
649            }
650            let retractions = retractions.collect::<Vec<_>>().await;
651            let retractions = retractions
652                .into_iter()
653                .flat_map(|(data_write, retractions)| {
654                    self.datas.put_write_for_apply(data_write);
655                    retractions
656                })
657                .collect();
658
659            // Remove all the applied registers.
660            self.txns_cache.mark_register_applied(ts);
661
662            debug!("apply_le {:?} success", ts);
663            Tidy { retractions }
664        })
665        .await
666    }
667
668    /// Commits the tidy work at the given time.
669    ///
670    /// Mostly a helper to make it obvious that we can throw away the apply work
671    /// (and not get into an infinite cycle of tidy->apply->tidy).
672    #[cfg(test)]
673    pub async fn tidy_at(&mut self, tidy_ts: T, tidy: Tidy) -> Result<(), T> {
674        debug!("tidy at {:?}", tidy_ts);
675
676        let mut txn = self.begin();
677        txn.tidy(tidy);
678        // We just constructed this txn, so it couldn't have committed any
679        // batches, and thus there's nothing to apply. We're free to throw it
680        // away.
681        let apply = txn.commit_at(self, tidy_ts.clone()).await?;
682        assert!(apply.is_empty());
683
684        debug!("tidy at {:?} success", tidy_ts);
685        Ok(())
686    }
687
688    /// Allows compaction to the txns shard as well as internal representations,
689    /// losing the ability to answer queries about times less_than since_ts.
690    ///
691    /// In practice, this will likely only be called from the singleton
692    /// controller process.
693    pub async fn compact_to(&mut self, mut since_ts: T) {
694        let op = &self.metrics.compact_to;
695        op.run(async {
696            tracing::debug!("compact_to {:?}", since_ts);
697            let _ = self.txns_cache.update_gt(&since_ts).await;
698
699            // NB: A critical invariant for how this all works is that we never
700            // allow the since of the txns shard to pass any unapplied writes, so
701            // reduce it as necessary.
702            let min_unapplied_ts = self.txns_cache.min_unapplied_ts();
703            if min_unapplied_ts < &since_ts {
704                since_ts.clone_from(min_unapplied_ts);
705            }
706            crate::cads::<T, C>(&mut self.txns_since, since_ts).await;
707        })
708        .await
709    }
710
    /// Upgrade the version on the backing shard.
    ///
    /// In practice, this will likely only be called from the singleton
    /// controller process.
    ///
    /// # Panics
    ///
    /// Panics if the underlying since handle reports an error; per the expect
    /// message, an error here indicates invalid usage rather than a transient
    /// failure.
    pub async fn upgrade_version(&mut self) {
        self.txns_since
            .upgrade_version()
            .await
            .expect("invalid usage")
    }
721
    /// Returns the [ShardId] of the txns shard.
    ///
    /// This is simply the shard that the internal write handle is opened on.
    pub fn txns_id(&self) -> ShardId {
        self.txns_write.shard_id()
    }
726
    /// Returns the [TxnsCache] used by this handle.
    ///
    /// The reference is shared/read-only; advancing the cache is done through
    /// `&mut self` methods on the handle itself.
    pub fn read_cache(&self) -> &TxnsCache<T, C> {
        &self.txns_cache
    }
731}
732
/// A token representing maintenance writes (in particular, retractions) to the
/// txns shard.
///
/// This can be written on its own with `TxnsHandle::tidy_at` or sidecar'd into
/// a normal txn with [Txn::tidy].
#[derive(Debug, Default)]
pub struct Tidy {
    // Pending retractions, keyed by raw batch bytes. The value is an 8-byte
    // encoded payload ([u8; 8] — presumably a Codec64-encoded timestamp;
    // confirm against TxnsEntry) plus the data shard the batch was written to.
    pub(crate) retractions: BTreeMap<Vec<u8>, ([u8; 8], ShardId)>,
}
742
743impl Tidy {
744    /// Merges the work represented by the other tidy into this one.
745    pub fn merge(&mut self, other: Tidy) {
746        self.retractions.extend(other.retractions)
747    }
748}
749
/// A helper to make a more targeted mutable borrow of self.
#[derive(Debug)]
pub(crate) struct DataHandles<K: Codec, V: Codec, T, D> {
    /// Dynamic configuration, consulted when constructing new apply handles.
    pub(crate) dyncfgs: ConfigSet,
    /// Client used to lazily open write handles for data shards.
    pub(crate) client: Arc<PersistClient>,
    /// See [DataWriteApply].
    ///
    /// This is lazily populated with the set of shards touched by `apply_le`.
    data_write_for_apply: BTreeMap<ShardId, DataWriteApply<K, V, T, D>>,
    /// See [DataWriteCommit].
    ///
    /// This contains the set of data shards registered but not yet forgotten
    /// with this particular write handle.
    ///
    /// NB: In the common case, this and `_for_apply` will contain the same set
    /// of shards, but this is not required. A shard can be in either and not
    /// the other.
    data_write_for_commit: BTreeMap<ShardId, DataWriteCommit<K, V, T, D>>,
}
769
770impl<K, V, T, D> DataHandles<K, V, T, D>
771where
772    K: Debug + Codec,
773    V: Debug + Codec,
774    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
775    D: Monoid + Ord + Codec64 + Send + Sync,
776{
777    async fn open_data_write_for_apply(&self, data_id: ShardId) -> DataWriteApply<K, V, T, D> {
778        let diagnostics = Diagnostics::from_purpose("txn data");
779        let schemas = self
780            .client
781            .latest_schema::<K, V, T, D>(data_id, diagnostics.clone())
782            .await
783            .expect("codecs have not changed");
784        let (key_schema, val_schema) = match schemas {
785            Some((_, key_schema, val_schema)) => (Arc::new(key_schema), Arc::new(val_schema)),
786            // We will always have at least one schema registered by the time we reach this point,
787            // because that is ensured at txn-registration time.
788            None => unreachable!("data shard {} should have a schema", data_id),
789        };
790        let wrapped = self
791            .client
792            .open_writer(data_id, key_schema, val_schema, diagnostics)
793            .await
794            .expect("schema shouldn't change");
795        DataWriteApply {
796            apply_ensure_schema_match: APPLY_ENSURE_SCHEMA_MATCH.handle(&self.dyncfgs),
797            client: Arc::clone(&self.client),
798            wrapped,
799        }
800    }
801
802    pub(crate) async fn take_write_for_apply(
803        &mut self,
804        data_id: &ShardId,
805    ) -> DataWriteApply<K, V, T, D> {
806        if let Some(data_write) = self.data_write_for_apply.remove(data_id) {
807            return data_write;
808        }
809        self.open_data_write_for_apply(*data_id).await
810    }
811
812    pub(crate) fn put_write_for_apply(&mut self, data_write: DataWriteApply<K, V, T, D>) {
813        self.data_write_for_apply
814            .insert(data_write.shard_id(), data_write);
815    }
816
817    pub(crate) fn take_write_for_commit(
818        &mut self,
819        data_id: &ShardId,
820    ) -> Option<DataWriteCommit<K, V, T, D>> {
821        self.data_write_for_commit.remove(data_id)
822    }
823
824    pub(crate) fn put_write_for_commit(&mut self, data_write: DataWriteCommit<K, V, T, D>) {
825        let prev = self
826            .data_write_for_commit
827            .insert(data_write.shard_id(), data_write);
828        assert!(prev.is_none());
829    }
830}
831
/// A newtype wrapper around [WriteHandle] indicating that it has a real schema
/// registered by the user.
///
/// The txn-wal user declares which schema they'd like to use for committing
/// batches by passing it in (as part of the WriteHandle) in the call to
/// register. This must be used to encode any new batches written. The wrapper
/// helps us from accidentally mixing up the WriteHandles that we internally
/// invent for applying the batches (which use a schema matching the one
/// declared in the batch).
#[derive(Debug)]
pub(crate) struct DataWriteCommit<K: Codec, V: Codec, T, D>(pub(crate) WriteHandle<K, V, T, D>);

// Deref/DerefMut let callers use this as a plain [WriteHandle], while the
// newtype's name still documents which kind of handle it is at use sites.
impl<K: Codec, V: Codec, T, D> Deref for DataWriteCommit<K, V, T, D> {
    type Target = WriteHandle<K, V, T, D>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteCommit<K, V, T, D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
857
/// A newtype wrapper around [WriteHandle] indicating that it can alter the
/// schema its using to match the one in the batches being appended.
///
/// When a batch is committed to txn-wal, it contains metadata about which
/// schemas were used to encode the data in it. Txn-wal then uses this info to
/// make sure that in [TxnsHandle::apply_le], that the `compare_and_append` call
/// happens on a handle with the same schema. This is accomplished by querying
/// the persist schema registry.
#[derive(Debug)]
pub(crate) struct DataWriteApply<K: Codec, V: Codec, T, D> {
    // Used to re-open the write handle when the batch schema differs from the
    // handle's schema.
    client: Arc<PersistClient>,
    // Dyncfg escape hatch: when false, the schema-matching logic is skipped.
    apply_ensure_schema_match: ConfigValHandle<bool>,
    pub(crate) wrapped: WriteHandle<K, V, T, D>,
}

// Deref/DerefMut let callers use this as a plain [WriteHandle]; the wrapped
// handle may be swapped out by `maybe_replace_with_batch_schema`.
impl<K: Codec, V: Codec, T, D> Deref for DataWriteApply<K, V, T, D> {
    type Target = WriteHandle<K, V, T, D>;

    fn deref(&self) -> &Self::Target {
        &self.wrapped
    }
}

impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteApply<K, V, T, D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.wrapped
    }
}
886
/// Dyncfg escape hatch ("CYA"): when set to false, the schema-matching logic
/// in apply becomes a no-op. Defaults to true (schema matching enabled).
pub(crate) const APPLY_ENSURE_SCHEMA_MATCH: Config<bool> = Config::new(
    "txn_wal_apply_ensure_schema_match",
    true,
    "CYA to skip updating write handle to batch schema in apply",
);
892
/// Returns the unique schema id among `schemas`, if any.
///
/// Returns `Ok(None)` for an empty iterator, `Ok(Some(id))` if every element
/// is equal to `id`, and `Err((new, first))` with the first conflicting pair
/// otherwise.
///
/// Generalized over the id type: any `Copy + PartialEq` works (the original
/// `SchemaId`-only signature already required both), which keeps production
/// callers unchanged while making the helper directly unit-testable.
fn at_most_one_schema<S: Copy + PartialEq>(
    schemas: impl Iterator<Item = S>,
) -> Result<Option<S>, (S, S)> {
    let mut schema = None;
    for s in schemas {
        match schema {
            None => schema = Some(s),
            // Conflict: report (newly seen, previously seen).
            Some(x) if s != x => return Err((s, x)),
            Some(_) => continue,
        }
    }
    Ok(schema)
}
906
907impl<K, V, T, D> DataWriteApply<K, V, T, D>
908where
909    K: Debug + Codec,
910    V: Debug + Codec,
911    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
912    D: Monoid + Ord + Codec64 + Send + Sync,
913{
914    pub(crate) async fn maybe_replace_with_batch_schema(&mut self, batches: &[Batch<K, V, T, D>]) {
915        // TODO: Remove this once everything is rolled out and we're sure it's
916        // not going to cause any issues.
917        if !self.apply_ensure_schema_match.get() {
918            return;
919        }
920        let batch_schema = at_most_one_schema(batches.iter().flat_map(|x| x.schemas()));
921        let batch_schema = batch_schema.unwrap_or_else(|_| {
922            panic!(
923                "txn-wal uses at most one schema to commit batches, got: {:?}",
924                batches.iter().flat_map(|x| x.schemas()).collect::<Vec<_>>()
925            )
926        });
927        let (batch_schema, handle_schema) = match (batch_schema, self.wrapped.schema_id()) {
928            (Some(batch_schema), Some(handle_schema)) if batch_schema != handle_schema => {
929                (batch_schema, handle_schema)
930            }
931            _ => return,
932        };
933
934        let data_id = self.shard_id();
935        let diagnostics = Diagnostics::from_purpose("txn data");
936        let (key_schema, val_schema) = self
937            .client
938            .get_schema::<K, V, T, D>(data_id, batch_schema, diagnostics.clone())
939            .await
940            .expect("codecs shouldn't change")
941            .expect("id must have been registered to create this batch");
942        let new_data_write = self
943            .client
944            .open_writer(
945                self.shard_id(),
946                Arc::new(key_schema),
947                Arc::new(val_schema),
948                diagnostics,
949            )
950            .await
951            .expect("codecs shouldn't change");
952        tracing::info!(
953            "updated {} write handle from {} to {} to apply batches",
954            data_id,
955            handle_schema,
956            batch_schema
957        );
958        assert_eq!(new_data_write.schema_id(), Some(batch_schema));
959        self.wrapped = new_data_write;
960    }
961}
962
963#[cfg(test)]
964mod tests {
965    use std::time::{Duration, UNIX_EPOCH};
966
967    use differential_dataflow::Hashable;
968    use futures::future::BoxFuture;
969    use mz_ore::assert_none;
970    use mz_ore::cast::CastFrom;
971    use mz_ore::collections::CollectionExt;
972    use mz_ore::metrics::MetricsRegistry;
973    use mz_persist_client::PersistLocation;
974    use mz_persist_client::cache::PersistClientCache;
975    use mz_persist_client::cfg::RetryParameters;
976    use rand::rngs::SmallRng;
977    use rand::{RngCore, SeedableRng};
978    use timely::progress::Antichain;
979    use tokio::sync::oneshot;
980    use tracing::{Instrument, info, info_span};
981
982    use crate::operator::DataSubscribe;
983    use crate::tests::{CommitLog, reader, write_directly, writer};
984
985    use super::*;
986
987    impl TxnsHandle<String, (), u64, i64, TxnsCodecDefault> {
988        pub(crate) async fn expect_open(client: PersistClient) -> Self {
989            Self::expect_open_id(client, ShardId::new()).await
990        }
991
992        pub(crate) async fn expect_open_id(client: PersistClient, txns_id: ShardId) -> Self {
993            let dyncfgs = crate::all_dyncfgs(client.dyncfgs().clone());
994            Self::open(
995                0,
996                client,
997                dyncfgs,
998                Arc::new(Metrics::new(&MetricsRegistry::new())),
999                txns_id,
1000                Opaque::encode(&0u64),
1001            )
1002            .await
1003        }
1004
1005        pub(crate) fn new_log(&self) -> CommitLog {
1006            CommitLog::new((*self.datas.client).clone(), self.txns_id())
1007        }
1008
1009        pub(crate) async fn expect_register(&mut self, register_ts: u64) -> ShardId {
1010            self.expect_registers(register_ts, 1).await.into_element()
1011        }
1012
1013        pub(crate) async fn expect_registers(
1014            &mut self,
1015            register_ts: u64,
1016            amount: usize,
1017        ) -> Vec<ShardId> {
1018            let data_ids: Vec<_> = (0..amount).map(|_| ShardId::new()).collect();
1019            let mut writers = Vec::new();
1020            for data_id in &data_ids {
1021                writers.push(writer(&self.datas.client, *data_id).await);
1022            }
1023            self.register(register_ts, writers).await.unwrap();
1024            data_ids
1025        }
1026
1027        pub(crate) async fn expect_commit_at(
1028            &mut self,
1029            commit_ts: u64,
1030            data_id: ShardId,
1031            keys: &[&str],
1032            log: &CommitLog,
1033        ) -> Tidy {
1034            let mut txn = self.begin();
1035            for key in keys {
1036                txn.write(&data_id, (*key).into(), (), 1).await;
1037            }
1038            let tidy = txn
1039                .commit_at(self, commit_ts)
1040                .await
1041                .unwrap()
1042                .apply(self)
1043                .await;
1044            for key in keys {
1045                log.record((data_id, (*key).into(), commit_ts, 1));
1046            }
1047            tidy
1048        }
1049    }
1050
    // Exercises `register`'s contract: idempotence, rejection at closed-off
    // times (with the error carrying the first usable ts), and mixing new and
    // already-registered shards in one call.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn register_at() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(2).await;

        // Register a second time is a no-op (idempotent).
        txns.register(3, [writer(&client, d0).await]).await.unwrap();

        // Cannot register a new data shard at an already closed off time. An
        // error is returned with the first time that a registration would
        // succeed.
        let d1 = ShardId::new();
        assert_eq!(
            txns.register(2, [writer(&client, d1).await])
                .await
                .unwrap_err(),
            4
        );

        // Can still register after txns have been committed.
        txns.expect_commit_at(4, d0, &["foo"], &log).await;
        txns.register(5, [writer(&client, d1).await]).await.unwrap();

        // We can also register some new and some already registered shards.
        let d2 = ShardId::new();
        txns.register(6, [writer(&client, d0).await, writer(&client, d2).await])
            .await
            .unwrap();

        // Validate contents against the commit log. NOTE(review): d2 is
        // registered above but never written or snapshotted; that appears
        // intentional (the registration call itself is what's under test).
        let () = log.assert_snapshot(d0, 6).await;
        let () = log.assert_snapshot(d1, 6).await;
    }
1086
    /// A sanity check that CommitLog catches an incorrect usage (proxy for a
    /// bug that looks like an incorrect usage).
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    #[should_panic(expected = "left: [(\"foo\", 2, 1)]\n right: [(\"foo\", 2, 2)]")]
    async fn incorrect_usage_register_write_same_time() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;
        let mut d0_write = writer(&client, d0).await;

        // Commit a write at ts 2...
        let mut txn = txns.begin_test();
        txn.write(&d0, "foo".into(), (), 1).await;
        let apply = txn.commit_at(&mut txns, 2).await.unwrap();
        log.record_txn(2, &txn);
        // ... and (incorrectly) also write to the shard normally at ts 2.
        // This double-counts ("foo", 2): the same update now exists both in
        // the committed txn and as a direct write, so one side of the
        // comparison ends up with a diff of 2 (see `should_panic` above).
        let () = d0_write
            .compare_and_append(
                &[(("foo".to_owned(), ()), 2, 1)],
                Antichain::from_elem(2),
                Antichain::from_elem(3),
            )
            .await
            .unwrap()
            .unwrap();
        log.record((d0, "foo".into(), 2, 1));
        apply.apply(&mut txns).await;

        // Verify that CommitLog catches this.
        log.assert_snapshot(d0, 2).await;
    }
1120
    // Exercises `forget`'s contract: forgetting unregistered shards, single
    // and multiple registered shards, idempotence, rejection at closed-off
    // times, and re-register-then-forget-all.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn forget_at() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();

        // Can forget a data_shard that has not been registered.
        txns.forget(1, [ShardId::new()]).await.unwrap();

        // Can forget multiple data_shards that have not been registered.
        txns.forget(2, (0..5).map(|_| ShardId::new()))
            .await
            .unwrap();

        // Can forget a registered shard.
        let d0 = txns.expect_register(3).await;
        txns.forget(4, [d0]).await.unwrap();

        // Can forget multiple registered shards.
        let ds = txns.expect_registers(5, 5).await;
        txns.forget(6, ds.clone()).await.unwrap();

        // Forget is idempotent.
        txns.forget(7, [d0]).await.unwrap();
        txns.forget(8, ds.clone()).await.unwrap();

        // Cannot forget at an already closed off time. An error is returned
        // with the first time that a registration would succeed.
        let d1 = txns.expect_register(9).await;
        assert_eq!(txns.forget(9, [d1]).await.unwrap_err(), 10);

        // Write to txns and to d0 directly. (d0 was forgotten above, so the
        // direct write is legal; d1 is registered, so it goes through txns.)
        let mut d0_write = writer(&client, d0).await;
        txns.expect_commit_at(10, d1, &["d1"], &log).await;
        let updates = [(("d0".to_owned(), ()), 10, 1)];
        d0_write
            .compare_and_append(&updates, d0_write.shared_upper(), Antichain::from_elem(11))
            .await
            .unwrap()
            .unwrap();
        log.record((d0, "d0".into(), 10, 1));

        // Can register and forget an already registered and forgotten shard.
        txns.register(11, [writer(&client, d0).await])
            .await
            .unwrap();
        let mut forget_expected = vec![d0, d1];
        forget_expected.sort();
        assert_eq!(txns.forget_all(12).await.unwrap().0, forget_expected);

        // Close shard to writes
        //
        // NOTE(review): closing the upper to the empty antichain presumably
        // lets the snapshot below read to completion on a now-unregistered
        // shard; confirm against assert_snapshot's contract.
        d0_write
            .compare_and_append_batch(&mut [], d0_write.shared_upper(), Antichain::new(), true)
            .await
            .unwrap()
            .unwrap();

        let () = log.assert_snapshot(d0, 12).await;
        let () = log.assert_snapshot(d1, 12).await;

        // Same treatment for the shards forgotten at ts 6/8: close them off
        // and validate their contents as of ts 8.
        for di in ds {
            let mut di_write = writer(&client, di).await;

            // Close shards to writes
            di_write
                .compare_and_append_batch(&mut [], di_write.shared_upper(), Antichain::new(), true)
                .await
                .unwrap()
                .unwrap();

            let () = log.assert_snapshot(di, 8).await;
        }
    }
1195
    // Cycles a data shard through direct-write / register / txn-write / forget
    // repeatedly, verifying subscriptions started at every intermediate time.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn register_forget() {
        // Step a deterministic subset of the subscriptions past `ts`, so that
        // the set of subs sits at a variety of progresses at any given time.
        async fn step_some_past(subs: &mut Vec<DataSubscribe>, ts: u64) {
            for (idx, sub) in subs.iter_mut().enumerate() {
                // Only step some of them to try to maximize edge cases.
                if usize::cast_from(ts) % (idx + 1) == 0 {
                    async {
                        info!("stepping sub {} past {}", idx, ts);
                        sub.step_past(ts).await;
                    }
                    .instrument(info_span!("sub", idx))
                    .await;
                }
            }
        }

        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = ShardId::new();
        let mut d0_write = writer(&client, d0).await;
        let mut subs = Vec::new();

        // Loop for a while doing the following:
        // - Write directly to some time before register ts
        // - Register
        // - Write via txns
        // - Forget
        //
        // After each step, make sure that a subscription started at that time
        // works and that all subscriptions can be stepped through the expected
        // timestamp.
        let mut ts = 0;
        while ts < 32 {
            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
            ts += 1;
            info!("{} direct", ts);
            // The empty txn closes the txns shard off through `ts`, ensuring
            // d0 isn't registered at the time of the direct write (cf. the
            // same pattern in the stress test's `write_direct`).
            txns.begin().commit_at(&mut txns, ts).await.unwrap();
            write_directly(ts, &mut d0_write, &[&format!("d{}", ts)], &log).await;
            step_some_past(&mut subs, ts).await;
            // Occasionally compact, to exercise subscriptions whose as_of has
            // been compacted away.
            if ts % 11 == 0 {
                txns.compact_to(ts).await;
            }

            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
            ts += 1;
            info!("{} register", ts);
            txns.register(ts, [writer(&client, d0).await])
                .await
                .unwrap();
            step_some_past(&mut subs, ts).await;
            if ts % 11 == 0 {
                txns.compact_to(ts).await;
            }

            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
            ts += 1;
            info!("{} txns", ts);
            txns.expect_commit_at(ts, d0, &[&format!("t{}", ts)], &log)
                .await;
            step_some_past(&mut subs, ts).await;
            if ts % 11 == 0 {
                txns.compact_to(ts).await;
            }

            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
            ts += 1;
            info!("{} forget", ts);
            txns.forget(ts, [d0]).await.unwrap();
            step_some_past(&mut subs, ts).await;
            if ts % 11 == 0 {
                txns.compact_to(ts).await;
            }
        }

        // Check all the subscribes.
        for mut sub in subs.into_iter() {
            sub.step_past(ts).await;
            log.assert_eq(d0, sub.as_of, sub.progress(), sub.output().clone());
        }
    }
1278
    // Regression test for a bug encountered during initial development:
    // - task 0 commits to a data shard at ts T
    // - before task 0 can unblock T for reads, task 1 tries to register it at
    //   time T (ditto T+X), this does a caa of empty space, advancing the upper
    //   of the data shard to T+1
    // - task 0 attempts to caa in the batch, but finds out the upper is T+1 and
    //   assumes someone else did the work
    // - result: the write is lost
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn race_data_shard_register_and_commit() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let d0 = txns.expect_register(1).await;

        // Commit at ts 2 but intentionally hold back the apply, recreating
        // the window in which the register below races the commit.
        let mut txn = txns.begin();
        txn.write(&d0, "foo".into(), (), 1).await;
        let commit_apply = txn.commit_at(&mut txns, 2).await.unwrap();

        txns.register(3, [writer(&client, d0).await]).await.unwrap();

        // Make sure that we can read empty at the register commit time even
        // before the txn commit apply.
        let actual = txns.txns_cache.expect_snapshot(&client, d0, 1).await;
        assert_eq!(actual, Vec::<String>::new());

        // The delayed apply must still land the write: reading at the commit
        // ts afterwards sees it (i.e. the write was not lost).
        commit_apply.apply(&mut txns).await;
        let actual = txns.txns_cache.expect_snapshot(&client, d0, 2).await;
        assert_eq!(actual, vec!["foo".to_owned()]);
    }
1309
    // A test that applies a batch of writes all at once.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn apply_many_ts() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;

        // Commit at each of ts 2..=9 but never apply, deliberately leaving a
        // backlog of unapplied txns (the `_apply` tokens are dropped).
        for ts in 2..10 {
            let mut txn = txns.begin();
            txn.write(&d0, ts.to_string(), (), 1).await;
            let _apply = txn.commit_at(&mut txns, ts).await.unwrap();
            log.record((d0, ts.to_string(), ts, 1));
        }
        // This automatically runs the apply, which catches up all the previous
        // txns at once.
        txns.expect_commit_at(10, d0, &[], &log).await;

        log.assert_snapshot(d0, 10).await;
    }
1331
    // State for one worker in the stress test: a txns handle, a commit log to
    // validate against, and bookkeeping for in-flight async reads.
    struct StressWorker {
        // Index of this worker; used for logging and unique key generation.
        idx: usize,
        // The set of data shards this worker randomly operates on.
        data_ids: Vec<ShardId>,
        txns: TxnsHandle<String, (), u64, i64>,
        log: CommitLog,
        // Tidy work accumulated from applied txns, folded into later txns.
        tidy: Tidy,
        // The worker's current timestamp, advanced as operations race.
        ts: u64,
        // Monotonic step counter; used for logging and unique key generation.
        step: usize,
        rng: SmallRng,
        // In-flight reads: a channel to tell the read task what progress to
        // wait for, the shard and as_of being read, and the task handle that
        // yields the captured output.
        reads: Vec<(
            oneshot::Sender<u64>,
            ShardId,
            u64,
            mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
        )>,
    }
1348
1349    impl StressWorker {
        // Performs one randomly chosen operation against a randomly chosen
        // data shard, then advances the step counter.
        pub async fn step(&mut self) {
            debug!(
                "stress {} step {} START ts={}",
                self.idx, self.step, self.ts
            );
            let data_id =
                self.data_ids[usize::cast_from(self.rng.next_u64()) % self.data_ids.len()];
            match self.rng.next_u64() % 6 {
                0 => self.write(data_id).await,
                // The register and forget impls intentionally don't switch on
                // whether it's already registered to stress idempotence.
                1 => self.register(data_id).await,
                2 => self.forget(data_id).await,
                3 => {
                    debug!("stress update {:.9} to {}", data_id.to_string(), self.ts);
                    let _ = self.txns.txns_cache.update_ge(&self.ts).await;
                }
                // Two arms start reads, making reads twice as likely as each
                // of the other operations.
                4 => self.start_read(data_id),
                5 => self.start_read(data_id),
                _ => unreachable!(""),
            }
            debug!("stress {} step {} DONE ts={}", self.idx, self.step, self.ts);
            self.step += 1;
        }
1374
        // Returns a key unique to this worker and step, so that concurrent
        // workers never write colliding keys.
        fn key(&self) -> String {
            format!("w{}s{}", self.idx, self.step)
        }
1378
        // Advances `self.ts` to the txns cache's progress and reports whether
        // `data_id` is registered as of that time.
        async fn registered_at_progress_ts(&mut self, data_id: ShardId) -> bool {
            self.ts = *self.txns.txns_cache.update_ge(&self.ts).await;
            self.txns
                .txns_cache
                .registered_at_progress(&data_id, &self.ts)
        }
1385
        // Writes to the given data shard, either via txns if it's registered or
        // directly if it's not.
        //
        // NOTE(review): `retry_ts_err` (defined below) retries the closure on
        // `Err(ts)` results — presumably advancing `w.ts` between attempts;
        // confirm at its definition.
        async fn write(&mut self, data_id: ShardId) {
            // Make sure to keep the registered_at_ts call _inside_ the retry
            // loop, because a data shard might switch between registered or not
            // registered as the loop advances through timestamps.
            self.retry_ts_err(&mut |w: &mut StressWorker| {
                Box::pin(async move {
                    if w.registered_at_progress_ts(data_id).await {
                        w.write_via_txns(data_id).await
                    } else {
                        w.write_direct(data_id).await
                    }
                })
            })
            .await
        }
1403
        // Commits `self.key()` to `data_id` at `self.ts` through the txn
        // system, folding in any accumulated tidy work. Returns `Err(ts)` if
        // the commit lost a timestamp race.
        async fn write_via_txns(&mut self, data_id: ShardId) -> Result<(), u64> {
            debug!(
                "stress write_via_txns {:.9} at {}",
                data_id.to_string(),
                self.ts
            );
            // HACK: Normally, we'd make sure that this particular handle had
            // registered the data shard before writing to it, but that would
            // consume a ts and isn't quite how we want `write_via_txns` to
            // work. Work around that by setting a write handle (with a schema
            // that we promise is correct) in the right place.
            if !self.txns.datas.data_write_for_commit.contains_key(&data_id) {
                let x = writer(&self.txns.datas.client, data_id).await;
                self.txns
                    .datas
                    .data_write_for_commit
                    .insert(data_id, DataWriteCommit(x));
            }
            let mut txn = self.txns.begin_test();
            txn.tidy(std::mem::take(&mut self.tidy));
            txn.write(&data_id, self.key(), (), 1).await;
            let apply = txn.commit_at(&mut self.txns, self.ts).await?;
            debug!(
                "log {:.9} {} at {}",
                data_id.to_string(),
                self.key(),
                self.ts
            );
            self.log.record_txn(self.ts, &txn);
            // Only apply the txn sometimes, to stress the unapplied-txn paths;
            // when we do, save the resulting tidy work for a later txn.
            if self.rng.next_u64() % 3 == 0 {
                self.tidy.merge(apply.apply(&mut self.txns).await);
            }
            Ok(())
        }
1438
        // Writes `self.key()` directly to `data_id` (bypassing txns) at
        // `self.ts`, first proving via an empty txn commit that the shard
        // isn't registered at that ts. Returns `Err(ts)` if the timestamp has
        // already been closed off.
        async fn write_direct(&mut self, data_id: ShardId) -> Result<(), u64> {
            debug!(
                "stress write_direct {:.9} at {}",
                data_id.to_string(),
                self.ts
            );
            // First write an empty txn to ensure that the shard isn't
            // registered at this ts by someone else.
            self.txns.begin().commit_at(&mut self.txns, self.ts).await?;

            let mut write = writer(&self.txns.datas.client, data_id).await;
            let mut current = write.shared_upper().into_option().unwrap();
            loop {
                // The shard's upper has advanced past our ts: give up and let
                // the caller retry at a new timestamp.
                if !(current <= self.ts) {
                    return Err(current);
                }
                let key = self.key();
                let updates = [((&key, &()), &self.ts, 1)];
                let res = crate::small_caa(
                    || format!("data {:.9} direct", data_id.to_string()),
                    &mut write,
                    &updates,
                    current,
                    self.ts + 1,
                )
                .await;
                match res {
                    Ok(()) => {
                        debug!("log {:.9} {} at {}", data_id.to_string(), key, self.ts);
                        self.log.record((data_id, key, self.ts, 1));
                        return Ok(());
                    }
                    // Upper mismatch: adopt the new upper and try again.
                    Err(new_current) => current = new_current,
                }
            }
        }
1475
        // Registers `data_id` at the worker's current ts, retrying timestamp
        // conflicts. Intentionally doesn't check whether the shard is already
        // registered, to stress idempotence (see `step`).
        async fn register(&mut self, data_id: ShardId) {
            self.retry_ts_err(&mut |w: &mut StressWorker| {
                debug!("stress register {:.9} at {}", data_id.to_string(), w.ts);
                Box::pin(async move {
                    let data_write = writer(&w.txns.datas.client, data_id).await;
                    let _ = w.txns.register(w.ts, [data_write]).await?;
                    Ok(())
                })
            })
            .await
        }
1487
        // Forgets `data_id` at the worker's current ts, retrying timestamp
        // conflicts. Like `register`, intentionally doesn't check whether the
        // shard is currently registered (see `step`).
        async fn forget(&mut self, data_id: ShardId) {
            self.retry_ts_err(&mut |w: &mut StressWorker| {
                debug!("stress forget {:.9} at {}", data_id.to_string(), w.ts);
                Box::pin(async move { w.txns.forget(w.ts, [data_id]).await.map(|_| ()) })
            })
            .await
        }
1495
        // Spawns a blocking task that subscribes to `data_id` as of the
        // current ts and steps the dataflow until told — via the returned
        // oneshot — what progress to reach before capturing final output.
        fn start_read(&mut self, data_id: ShardId) {
            debug!(
                "stress start_read {:.9} at {}",
                data_id.to_string(),
                self.ts
            );
            let client = (*self.txns.datas.client).clone();
            let txns_id = self.txns.txns_id();
            let as_of = self.ts;
            debug!("start_read {:.9} as_of {}", data_id.to_string(), as_of);
            let (tx, mut rx) = oneshot::channel();
            let subscribe = mz_ore::task::spawn_blocking(
                || format!("{:.9}-{}", data_id.to_string(), as_of),
                move || {
                    let mut subscribe = DataSubscribe::new(
                        "test",
                        client,
                        txns_id,
                        data_id,
                        as_of,
                        Antichain::new(),
                    );
                    let data_id = format!("{:.9}", data_id.to_string());
                    let _guard = info_span!("read_worker", %data_id, as_of).entered();
                    // Keep stepping until the main task sends a target
                    // progress (or hangs up, in which case 0 means "stop as
                    // soon as possible").
                    loop {
                        subscribe.worker.step_or_park(None);
                        subscribe.capture_output();
                        let until = match rx.try_recv() {
                            Ok(ts) => ts,
                            Err(oneshot::error::TryRecvError::Empty) => {
                                continue;
                            }
                            Err(oneshot::error::TryRecvError::Closed) => 0,
                        };
                        // Target received: run the subscription past it, then
                        // return everything captured so far.
                        while subscribe.progress() < until {
                            subscribe.worker.step_or_park(None);
                            subscribe.capture_output();
                        }
                        return subscribe.output().clone();
                    }
                },
            );
            self.reads.push((tx, data_id, as_of, subscribe));
        }
1540
1541        async fn retry_ts_err<W>(&mut self, work_fn: &mut W)
1542        where
1543            W: for<'b> FnMut(&'b mut Self) -> BoxFuture<'b, Result<(), u64>>,
1544        {
1545            loop {
1546                match work_fn(self).await {
1547                    Ok(ret) => return ret,
1548                    Err(new_ts) => self.ts = new_ts,
1549                }
1550            }
1551        }
1552    }
1553
    // Randomized stress test: several workers, each with its own persist
    // client, concurrently run random operations (commits, registers, forgets,
    // reads — see StressWorker::step) against the same txns shard and a shared
    // set of data shards, then the results are checked against the in-memory
    // CommitLog oracle.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn stress_correctness() {
        const NUM_DATA_SHARDS: usize = 2;
        const NUM_WORKERS: usize = 2;
        const NUM_STEPS_PER_WORKER: usize = 100;
        // Fresh seed per run; printed so a failure can be reproduced.
        let seed = UNIX_EPOCH.elapsed().unwrap().hashed();
        eprintln!("using seed {}", seed);

        let mut clients = PersistClientCache::new_no_metrics();
        // We disable pubsub below, so retune the listen retries (pubsub
        // fallback) to keep the test speedy.
        clients
            .cfg()
            .set_next_listen_batch_retryer(RetryParameters {
                fixed_sleep: Duration::ZERO,
                initial_backoff: Duration::from_millis(1),
                multiplier: 1,
                clamp: Duration::from_millis(1),
            });
        let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let data_ids = (0..NUM_DATA_SHARDS)
            .map(|_| ShardId::new())
            .collect::<Vec<_>>();
        let data_writes = data_ids
            .iter()
            .map(|data_id| writer(&client, *data_id))
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .await;
        // Hold read handles for the duration of the test; presumably this
        // keeps the data shards' sinces from advancing — TODO confirm.
        let _data_sinces = data_ids
            .iter()
            .map(|data_id| reader(&client, *data_id))
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .await;
        let register_ts = 1;
        txns.register(register_ts, data_writes).await.unwrap();

        let mut workers = Vec::new();
        for idx in 0..NUM_WORKERS {
            // Clear the state cache between each client to maximally disconnect
            // them from each other.
            clients.clear_state_cache();
            let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
            let mut worker = StressWorker {
                idx,
                log: log.clone(),
                txns: TxnsHandle::expect_open_id(client.clone(), txns.txns_id()).await,
                data_ids: data_ids.clone(),
                tidy: Tidy::default(),
                ts: register_ts,
                step: 0,
                // Derive each worker's rng from the shared seed so the whole
                // run is reproducible from the one printed value.
                rng: SmallRng::seed_from_u64(seed.wrapping_add(u64::cast_from(idx))),
                reads: Vec::new(),
            };
            let worker = async move {
                while worker.step < NUM_STEPS_PER_WORKER {
                    worker.step().await;
                }
                (worker.ts, worker.reads)
            }
            .instrument(info_span!("stress_worker", idx));
            workers.push(mz_ore::task::spawn(|| format!("worker-{}", idx), worker));
        }

        // Collect every worker's final timestamp and outstanding reads.
        let mut max_ts = 0;
        let mut reads = Vec::new();
        for worker in workers {
            let (t, mut r) = worker.await;
            max_ts = std::cmp::max(max_ts, t);
            reads.append(&mut r);
        }

        // Run all of the following in a timeout to make hangs easier to debug.
        tokio::time::timeout(Duration::from_secs(30), async {
            info!("finished with max_ts of {}", max_ts);
            txns.apply_le(&max_ts).await;
            // Verify each data shard's contents against the log oracle.
            for data_id in data_ids {
                info!("reading data shard {}", data_id);
                log.assert_snapshot(data_id, max_ts)
                    .instrument(info_span!("read_data", data_id = format!("{:.9}", data_id)))
                    .await;
            }
            info!("now waiting for reads {}", max_ts);
            // Release each read task by telling it to read up to max_ts + 1,
            // then check its output against the oracle.
            for (tx, data_id, as_of, subscribe) in reads {
                let _ = tx.send(max_ts + 1);
                let output = subscribe.await;
                log.assert_eq(data_id, as_of, max_ts + 1, output);
            }
        })
        .await
        .unwrap();
    }
1650
1651    #[mz_ore::test(tokio::test)]
1652    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1653    async fn advance_physical_uppers_past() {
1654        let client = PersistClient::new_for_tests().await;
1655        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1656        let log = txns.new_log();
1657        let d0 = txns.expect_register(1).await;
1658        let mut d0_write = writer(&client, d0).await;
1659        let d1 = txns.expect_register(2).await;
1660        let mut d1_write = writer(&client, d1).await;
1661
1662        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[2]);
1663        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1664
1665        // Normal `apply` (used by expect_commit_at) does not advance the
1666        // physical upper of data shards that were not involved in the txn (lazy
1667        // upper). d1 is not involved in this txn so stays where it is.
1668        txns.expect_commit_at(3, d0, &["0-2"], &log).await;
1669        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1670        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1671
1672        // d0 is not involved in this txn so stays where it is.
1673        txns.expect_commit_at(4, d1, &["1-3"], &log).await;
1674        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1675        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[5]);
1676
1677        log.assert_snapshot(d0, 4).await;
1678        log.assert_snapshot(d1, 4).await;
1679    }
1680
    // Verifies that the per-shard write handle (and its schema) tracked in
    // `datas.data_write_for_commit` is only present on handles that registered
    // the shard, is required for committing batches (but not for applying
    // already-committed txns), and is removed by forget / restored by
    // re-register.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    #[allow(clippy::unnecessary_get_then_check)] // Makes it less readable.
    async fn schemas() {
        let client = PersistClient::new_for_tests().await;
        // Two handles to the same txns shard: txns0 registers d0, txns1 never
        // does.
        let mut txns0 = TxnsHandle::expect_open(client.clone()).await;
        let mut txns1 = TxnsHandle::expect_open_id(client.clone(), txns0.txns_id()).await;
        let log = txns0.new_log();
        let d0 = txns0.expect_register(1).await;

        // The register call happened on txns0, which means it has a real schema
        // and can commit batches.
        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
        let mut txn = txns0.begin_test();
        txn.write(&d0, "foo".into(), (), 1).await;
        let apply = txn.commit_at(&mut txns0, 2).await.unwrap();
        log.record_txn(2, &txn);

        // We can use handle without a register call to apply a committed txn.
        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
        let _tidy = apply.apply(&mut txns1).await;

        // However, it cannot commit batches.
        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
        // Run the commit in a separate task so the expected panic can be
        // observed as a join error instead of failing the test.
        let res = mz_ore::task::spawn(|| "test", async move {
            let mut txn = txns1.begin();
            txn.write(&d0, "bar".into(), (), 1).await;
            // This panics.
            let _ = txn.commit_at(&mut txns1, 3).await;
        })
        .into_tokio_handle();
        assert!(res.await.is_err());

        // Forgetting the data shard removes it, so we don't leave the schema
        // sitting around.
        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
        txns0.forget(3, [d0]).await.unwrap();
        assert_none!(txns0.datas.data_write_for_commit.get(&d0));

        // Forget is idempotent.
        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
        txns0.forget(4, [d0]).await.unwrap();
        assert_none!(txns0.datas.data_write_for_commit.get(&d0));

        // We can register it again and commit again.
        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
        txns0
            .register(5, [writer(&client, d0).await])
            .await
            .unwrap();
        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
        txns0.expect_commit_at(6, d0, &["baz"], &log).await;

        log.assert_snapshot(d0, 6).await;
    }
1736}