mz_txn_wal/
txns.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An interface for atomic multi-shard writes.
11
12use std::collections::BTreeMap;
13use std::fmt::Debug;
14use std::ops::{Deref, DerefMut};
15use std::sync::Arc;
16
17use differential_dataflow::difference::Monoid;
18use differential_dataflow::lattice::Lattice;
19use futures::StreamExt;
20use futures::stream::FuturesUnordered;
21use mz_dyncfg::{Config, ConfigSet, ConfigValHandle};
22use mz_ore::collections::HashSet;
23use mz_ore::instrument;
24use mz_persist_client::batch::Batch;
25use mz_persist_client::cfg::USE_CRITICAL_SINCE_TXN;
26use mz_persist_client::critical::SinceHandle;
27use mz_persist_client::write::WriteHandle;
28use mz_persist_client::{Diagnostics, PersistClient, ShardId};
29use mz_persist_types::schema::SchemaId;
30use mz_persist_types::txn::{TxnsCodec, TxnsEntry};
31use mz_persist_types::{Codec, Codec64, Opaque, StepForward};
32use timely::order::TotalOrder;
33use timely::progress::Timestamp;
34use tracing::debug;
35
36use crate::TxnsCodecDefault;
37use crate::metrics::Metrics;
38use crate::txn_cache::{TxnsCache, Unapplied};
39use crate::txn_write::Txn;
40
41/// An interface for atomic multi-shard writes.
42///
43/// This handle is acquired through [Self::open]. Any data shards must be
44/// registered with [Self::register] before use. Transactions are then started
45/// with [Self::begin].
46///
47/// # Implementation Details
48///
49/// The structure of the txns shard is `(ShardId, Vec<u8>)` updates.
50///
51/// The core mechanism is that a txn commits a set of transmittable persist
52/// _batch handles_ as `(ShardId, <opaque blob>)` pairs at a single timestamp.
53/// This contractually both commits the txn and advances the logical upper of
54/// _every_ data shard (not just the ones involved in the txn).
55///
56/// Example:
57///
58/// ```text
59/// // A txn to only d0 at ts=1
60/// (d0, <opaque blob A>, 1, 1)
61/// // A txn to d0 (two blobs) and d1 (one blob) at ts=4
62/// (d0, <opaque blob B>, 4, 1)
63/// (d0, <opaque blob C>, 4, 1)
64/// (d1, <opaque blob D>, 4, 1)
65/// ```
66///
/// However, the new commit is not yet readable until the txn apply has run,
/// which is expected to be done promptly by the committer, except in the event
/// of a crash. The apply moves the batch handles, in ts order, into the data
/// shards with a [compare_and_append_batch] (similar to how the multi-worker
/// persist_sink works).
72///
73/// [compare_and_append_batch]:
74///     mz_persist_client::write::WriteHandle::compare_and_append_batch
75///
/// Once apply is run, we "tidy" the txns shard by retracting the update that
/// added the batch. As a result, the contents of the txns shard at any given
/// timestamp are exactly the set of outstanding apply work (plus registrations,
/// see below).
80///
81/// Example (building on the above):
82///
83/// ```text
84/// // Tidy for the first txn at ts=3
85/// (d0, <opaque blob A>, 3, -1)
86/// // Tidy for the second txn (the timestamps can be different for each
87/// // retraction in a txn, but don't need to be)
88/// (d0, <opaque blob B>, 5, -1)
89/// (d0, <opaque blob C>, 6, -1)
90/// (d1, <opaque blob D>, 6, -1)
91/// ```
92///
93/// To make it easy to reason about exactly which data shards are registered in
94/// the txn set at any given moment, the data shard is added to the set with a
/// `(ShardId, <empty>)` pair. The data shard may not be read before the
/// timestamp of the update (which starts at the time it was initialized, but it
/// may later be forwarded).
98///
99/// Example (building on both of the above):
100///
101/// ```text
102/// // d0 and d1 were both initialized before they were used above
103/// (d0, <empty>, 0, 1)
104/// (d1, <empty>, 2, 1)
105/// ```
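///
/// End to end, the intended usage looks roughly like the following sketch.
/// (Runnable versions live in this file's tests; `client`, `dyncfgs`,
/// `metrics`, and the shard ids are placeholders provided by the caller.)
///
/// ```text
/// let mut txns = TxnsHandle::<String, (), u64, i64>::open(
///     0, client, dyncfgs, metrics, txns_id,
/// ).await;
/// // Register the data shards that will participate in txns.
/// txns.register(1, [d0_write]).await.expect("ts 1 still open");
/// // Commit an atomic write to d0 at ts=2...
/// let mut txn = txns.begin();
/// txn.write(&d0, "foo".into(), (), 1).await;
/// let apply = txn.commit_at(&mut txns, 2).await.expect("ts 2 still open");
/// // ...and promptly unblock reads of it, collecting the retraction work.
/// let tidy = apply.apply(&mut txns).await;
/// // Fold the retraction work into a later txn.
/// let mut txn = txns.begin();
/// txn.tidy(tidy);
/// // (write to data shards and commit `txn` as usual; committing it also
/// // retracts the already-applied batches from the txns shard)
/// ```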
106#[derive(Debug)]
107pub struct TxnsHandle<K: Codec, V: Codec, T, D, O = u64, C: TxnsCodec = TxnsCodecDefault> {
108    pub(crate) metrics: Arc<Metrics>,
109    pub(crate) txns_cache: TxnsCache<T, C>,
110    pub(crate) txns_write: WriteHandle<C::Key, C::Val, T, i64>,
111    pub(crate) txns_since: SinceHandle<C::Key, C::Val, T, i64, O>,
112    pub(crate) datas: DataHandles<K, V, T, D>,
113}
114
115impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
116where
117    K: Debug + Codec,
118    V: Debug + Codec,
119    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
120    D: Debug + Monoid + Ord + Codec64 + Send + Sync,
121    O: Opaque + Debug + Codec64,
122    C: TxnsCodec,
123{
124    /// Returns a [TxnsHandle] committing to the given txn shard.
125    ///
    /// `txns_id` identifies which shard will be used as the txns WAL. MZ will
    /// likely have one of these per environment, used by all processes and kept
    /// the same across restarts.
129    ///
130    /// This also does any (idempotent) initialization work: i.e. ensures that
131    /// the txn shard is readable at `init_ts` by appending an empty batch, if
132    /// necessary.
133    pub async fn open(
134        init_ts: T,
135        client: PersistClient,
136        dyncfgs: ConfigSet,
137        metrics: Arc<Metrics>,
138        txns_id: ShardId,
139    ) -> Self {
140        let (txns_key_schema, txns_val_schema) = C::schemas();
141        let (mut txns_write, txns_read) = client
142            .open(
143                txns_id,
144                Arc::new(txns_key_schema),
145                Arc::new(txns_val_schema),
146                Diagnostics {
147                    shard_name: "txns".to_owned(),
148                    handle_purpose: "commit txns".to_owned(),
149                },
150                USE_CRITICAL_SINCE_TXN.get(client.dyncfgs()),
151            )
152            .await
153            .expect("txns schema shouldn't change");
154        let txns_since = client
155            .open_critical_since(
156                txns_id,
157                // TODO: We likely need to use a different critical reader id
158                // for this if we want to be able to introspect it via SQL.
159                PersistClient::CONTROLLER_CRITICAL_SINCE,
160                Diagnostics {
161                    shard_name: "txns".to_owned(),
162                    handle_purpose: "commit txns".to_owned(),
163                },
164            )
165            .await
166            .expect("txns schema shouldn't change");
167        let txns_cache = TxnsCache::init(init_ts, txns_read, &mut txns_write).await;
168        TxnsHandle {
169            metrics,
170            txns_cache,
171            txns_write,
172            txns_since,
173            datas: DataHandles {
174                dyncfgs,
175                client: Arc::new(client),
176                data_write_for_apply: BTreeMap::new(),
177                data_write_for_commit: BTreeMap::new(),
178            },
179        }
180    }
181
182    /// Returns a new, empty transaction that can involve the data shards
183    /// registered with this handle.
184    pub fn begin(&self) -> Txn<K, V, T, D> {
185        // TODO: This is a method on the handle because we'll need WriteHandles
186        // once we start spilling to s3.
187        Txn::new()
188    }
189
190    /// Registers data shards for use with this txn set.
191    ///
    /// A registration entry is written to the txn shard. If it is not possible
    /// to register the data shards at the requested time, an Err will be
    /// returned with the minimum time at which they could be registered.
195    ///
196    /// This method is idempotent. Data shards currently registered at
197    /// `register_ts` will not be registered a second time. Specifically, this
198    /// method will return success when the most recent register ts `R` is
199    /// less_equal to `register_ts` AND there is no forget ts between `R` and
200    /// `register_ts`.
201    ///
202    /// As a side effect all txns <= register_ts are applied, including the
203    /// registration itself.
204    ///
    /// **WARNING!** While a data shard is registered to the txn set, writing to
    /// it directly (i.e. using a WriteHandle instead of the TxnsHandle, or
    /// registering it with another txn shard) will lead to incorrectness,
    /// undefined behavior, and (potentially sticky) panics.
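    ///
    /// For callers that choose their own timestamps, the Err can drive a retry
    /// loop, e.g. (a sketch; `open_data_write` stands in for however the caller
    /// opens a [WriteHandle] for the shard being registered):
    ///
    /// ```text
    /// let mut register_ts = /* caller-chosen */;
    /// let tidy = loop {
    ///     match txns.register(register_ts, [open_data_write().await]).await {
    ///         Ok(tidy) => break tidy,
    ///         // register_ts has already been closed off; retry at or after the
    ///         // returned minimum timestamp.
    ///         Err(min_ts) => register_ts = min_ts,
    ///     }
    /// };
    /// ```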
209    #[instrument(level = "debug", fields(ts = ?register_ts))]
210    pub async fn register(
211        &mut self,
212        register_ts: T,
213        data_writes: impl IntoIterator<Item = WriteHandle<K, V, T, D>>,
214    ) -> Result<Tidy, T> {
215        let op = &Arc::clone(&self.metrics).register;
216        op.run(async {
217            let mut data_writes = data_writes.into_iter().collect::<Vec<_>>();
218
219            // The txns system requires that all participating data shards have a
220            // schema registered. Importantly, we must register a data shard's
221            // schema _before_ we publish it to the txns shard.
222            for data_write in &mut data_writes {
223                // Note that if this fails we'll bail out farther down in any case,
224                // so we might as well fail fast.
225                data_write
226                    .try_register_schema()
227                    .await
228                    .expect("schema should be registered");
229            }
230
231            let updates = data_writes
232                .iter()
233                .map(|data_write| {
234                    let data_id = data_write.shard_id();
235                    let entry = TxnsEntry::Register(data_id, T::encode(&register_ts));
236                    (data_id, C::encode(entry))
237                })
238                .collect::<Vec<_>>();
239            let data_ids_debug = || {
240                data_writes
241                    .iter()
242                    .map(|x| format!("{:.9}", x.shard_id().to_string()))
243                    .collect::<Vec<_>>()
244                    .join(" ")
245            };
246
247            let mut txns_upper = self
248                .txns_write
249                .shared_upper()
250                .into_option()
251                .expect("txns should not be closed");
252            loop {
253                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
254                // Figure out which are still unregistered as of `txns_upper`. Below
                // we write conditionally on the upper being what we expect so that
256                // we can re-run this if anything changes from underneath us.
257                let updates = updates
258                    .iter()
259                    .flat_map(|(data_id, (key, val))| {
260                        let registered =
261                            self.txns_cache.registered_at_progress(data_id, &txns_upper);
262                        (!registered).then_some(((key, val), &register_ts, 1))
263                    })
264                    .collect::<Vec<_>>();
265                // If the txns_upper has passed register_ts, we can no longer write.
266                if register_ts < txns_upper {
267                    debug!(
268                        "txns register {} at {:?} mismatch current={:?}",
269                        data_ids_debug(),
270                        register_ts,
271                        txns_upper,
272                    );
273                    return Err(txns_upper);
274                }
275
276                let res = crate::small_caa(
277                    || format!("txns register {}", data_ids_debug()),
278                    &mut self.txns_write,
279                    &updates,
280                    txns_upper,
281                    register_ts.step_forward(),
282                )
283                .await;
284                match res {
285                    Ok(()) => {
286                        debug!(
287                            "txns register {} at {:?} success",
288                            data_ids_debug(),
289                            register_ts
290                        );
291                        break;
292                    }
293                    Err(new_txns_upper) => {
294                        self.metrics.register.retry_count.inc();
295                        txns_upper = new_txns_upper;
296                        continue;
297                    }
298                }
299            }
300            for data_write in data_writes {
301                // If we already have a write handle for a newer version of a table, don't replace
302                // it! Currently we only support adding columns to tables with a default value, so
303                // the latest/newest schema will always be the most complete.
304                //
305                // TODO(alter_table): Revisit when we support dropping columns.
306                match self.datas.data_write_for_commit.get(&data_write.shard_id()) {
307                    None => {
308                        self.datas
309                            .data_write_for_commit
310                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
311                    }
312                    Some(previous) => {
313                        let new_schema_id = data_write.schema_id().expect("ensured above");
314
315                        if let Some(prev_schema_id) = previous.schema_id()
316                            && prev_schema_id > new_schema_id
317                        {
318                            mz_ore::soft_panic_or_log!(
319                                "tried registering a WriteHandle with an older SchemaId; \
320                                 prev_schema_id: {} new_schema_id: {} shard_id: {}",
321                                prev_schema_id,
322                                new_schema_id,
323                                previous.shard_id(),
324                            );
325                            continue;
326                        } else if previous.schema_id().is_none() {
327                            mz_ore::soft_panic_or_log!(
328                                "encountered data shard without a schema; shard_id: {}",
329                                previous.shard_id(),
330                            );
331                        }
332
333                        tracing::info!(
334                            prev_schema_id = ?previous.schema_id(),
335                            ?new_schema_id,
336                            shard_id = %previous.shard_id(),
337                            "replacing WriteHandle"
338                        );
339                        self.datas
340                            .data_write_for_commit
341                            .insert(data_write.shard_id(), DataWriteCommit(data_write));
342                    }
343                }
344            }
345            let tidy = self.apply_le(&register_ts).await;
346
347            Ok(tidy)
348        })
349        .await
350    }
351
352    /// Removes data shards from use with this txn set.
353    ///
354    /// The registration entry written to the txn shard is retracted. If it is
355    /// not possible to forget the data shard at the requested time, an Err will
356    /// be returned with the minimum time the data shards could be forgotten.
357    ///
358    /// This method is idempotent. Data shards currently forgotten at
359    /// `forget_ts` will not be forgotten a second time. Specifically, this
360    /// method will return success when the most recent forget ts (if any) `F`
361    /// is less_equal to `forget_ts` AND there is no register ts between `F` and
362    /// `forget_ts`.
363    ///
364    /// As a side effect all txns <= forget_ts are applied, including the
365    /// forget itself.
366    ///
    /// **WARNING!** While a data shard is registered to the txn set, writing to
    /// it directly (i.e. using a WriteHandle instead of the TxnsHandle, or
    /// registering it with another txn shard) will lead to incorrectness,
    /// undefined behavior, and (potentially sticky) panics.
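    ///
    /// For example (sketch; timestamps are illustrative):
    ///
    /// ```text
    /// // Forgetting a shard that was never registered still succeeds
    /// // (idempotent)...
    /// txns.forget(7, [d0]).await.expect("ts 7 still open");
    /// // ...and, like register, an Err carries the minimum ts that would work.
    /// if let Err(min_ts) = txns.forget(7, [d1]).await {
    ///     // retry at or after min_ts
    /// }
    /// ```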
371    #[instrument(level = "debug", fields(ts = ?forget_ts))]
372    pub async fn forget(
373        &mut self,
374        forget_ts: T,
375        data_ids: impl IntoIterator<Item = ShardId>,
376    ) -> Result<Tidy, T> {
377        let op = &Arc::clone(&self.metrics).forget;
378        op.run(async {
379            let data_ids = data_ids.into_iter().collect::<Vec<_>>();
380            let mut txns_upper = self
381                .txns_write
382                .shared_upper()
383                .into_option()
384                .expect("txns should not be closed");
385            loop {
386                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
387
388                let data_ids_debug = || {
389                    data_ids
390                        .iter()
391                        .map(|x| format!("{:.9}", x.to_string()))
392                        .collect::<Vec<_>>()
393                        .join(" ")
394                };
395                let updates = data_ids
396                    .iter()
                    // Never registered or already forgotten. This could change in
                    // `[txns_upper, forget_ts]` (due to races), so we still close off
                    // that interval before returning; we just don't write any updates
                    // for these shards.
400                    .filter(|data_id| self.txns_cache.registered_at_progress(data_id, &txns_upper))
401                    .map(|data_id| C::encode(TxnsEntry::Register(*data_id, T::encode(&forget_ts))))
402                    .collect::<Vec<_>>();
403                let updates = updates
404                    .iter()
405                    .map(|(key, val)| ((key, val), &forget_ts, -1))
406                    .collect::<Vec<_>>();
407
408                // If the txns_upper has passed forget_ts, we can no longer write.
409                if forget_ts < txns_upper {
410                    debug!(
411                        "txns forget {} at {:?} mismatch current={:?}",
412                        data_ids_debug(),
413                        forget_ts,
414                        txns_upper,
415                    );
416                    return Err(txns_upper);
417                }
418
                // Ensure the latest write for each of these shards has been applied,
                // so we don't run into any issues trying to apply it later.
421                {
422                    let data_ids: HashSet<_> = data_ids.iter().cloned().collect();
423                    let data_latest_unapplied = self
424                        .txns_cache
425                        .unapplied_batches
426                        .values()
427                        .rev()
428                        .find(|(x, _, _)| data_ids.contains(x));
429                    if let Some((_, _, latest_write)) = data_latest_unapplied {
430                        debug!(
431                            "txns forget {} applying latest write {:?}",
432                            data_ids_debug(),
433                            latest_write,
434                        );
435                        let latest_write = latest_write.clone();
436                        let _tidy = self.apply_le(&latest_write).await;
437                    }
438                }
439                let res = crate::small_caa(
440                    || format!("txns forget {}", data_ids_debug()),
441                    &mut self.txns_write,
442                    &updates,
443                    txns_upper,
444                    forget_ts.step_forward(),
445                )
446                .await;
447                match res {
448                    Ok(()) => {
449                        debug!(
450                            "txns forget {} at {:?} success",
451                            data_ids_debug(),
452                            forget_ts
453                        );
454                        break;
455                    }
456                    Err(new_txns_upper) => {
457                        self.metrics.forget.retry_count.inc();
458                        txns_upper = new_txns_upper;
459                        continue;
460                    }
461                }
462            }
463
            // Note: Ordering here matters: we want to generate the Tidy work _before_
            // removing the handle, because generating the work will create a handle to
            // the shard.
466            let tidy = self.apply_le(&forget_ts).await;
467            for data_id in &data_ids {
468                self.datas.data_write_for_commit.remove(data_id);
469            }
470
471            Ok(tidy)
472        })
473        .await
474    }
475
476    /// Forgets, at the given timestamp, every data shard that is registered.
477    /// Returns the ids of the forgotten shards. See [Self::forget].
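    ///
    /// For example (sketch):
    ///
    /// ```text
    /// // Forget every currently registered data shard at `ts`, getting back
    /// // the ids that were forgotten plus the resulting retraction work.
    /// let (forgotten, tidy) = txns.forget_all(ts).await.expect("ts still open");
    /// ```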
478    #[instrument(level = "debug", fields(ts = ?forget_ts))]
479    pub async fn forget_all(&mut self, forget_ts: T) -> Result<(Vec<ShardId>, Tidy), T> {
480        let op = &Arc::clone(&self.metrics).forget_all;
481        op.run(async {
482            let mut txns_upper = self
483                .txns_write
484                .shared_upper()
485                .into_option()
486                .expect("txns should not be closed");
487            let registered = loop {
488                txns_upper = self.txns_cache.update_ge(&txns_upper).await.clone();
489
490                let registered = self.txns_cache.all_registered_at_progress(&txns_upper);
491                let data_ids_debug = || {
492                    registered
493                        .iter()
494                        .map(|x| format!("{:.9}", x.to_string()))
495                        .collect::<Vec<_>>()
496                        .join(" ")
497                };
498                let updates = registered
499                    .iter()
500                    .map(|data_id| {
501                        C::encode(crate::TxnsEntry::Register(*data_id, T::encode(&forget_ts)))
502                    })
503                    .collect::<Vec<_>>();
504                let updates = updates
505                    .iter()
506                    .map(|(key, val)| ((key, val), &forget_ts, -1))
507                    .collect::<Vec<_>>();
508
509                // If the txns_upper has passed forget_ts, we can no longer write.
510                if forget_ts < txns_upper {
511                    debug!(
512                        "txns forget_all {} at {:?} mismatch current={:?}",
513                        data_ids_debug(),
514                        forget_ts,
515                        txns_upper,
516                    );
517                    return Err(txns_upper);
518                }
519
520                // Ensure the latest write has been applied, so we don't run into
521                // any issues trying to apply it later.
522                //
523                // NB: It's _very_ important for correctness to get this from the
524                // unapplied batches (which compact themselves naturally) and not
                // from the writes (which are artificially compacted based on when
                // reads are needed).
527                let data_latest_unapplied = self.txns_cache.unapplied_batches.values().last();
528                if let Some((_, _, latest_write)) = data_latest_unapplied {
529                    debug!(
530                        "txns forget_all {} applying latest write {:?}",
531                        data_ids_debug(),
532                        latest_write,
533                    );
534                    let latest_write = latest_write.clone();
535                    let _tidy = self.apply_le(&latest_write).await;
536                }
537                let res = crate::small_caa(
538                    || format!("txns forget_all {}", data_ids_debug()),
539                    &mut self.txns_write,
540                    &updates,
541                    txns_upper,
542                    forget_ts.step_forward(),
543                )
544                .await;
545                match res {
546                    Ok(()) => {
547                        debug!(
548                            "txns forget_all {} at {:?} success",
549                            data_ids_debug(),
550                            forget_ts
551                        );
552                        break registered;
553                    }
554                    Err(new_txns_upper) => {
555                        self.metrics.forget_all.retry_count.inc();
556                        txns_upper = new_txns_upper;
557                        continue;
558                    }
559                }
560            };
561
562            for data_id in registered.iter() {
563                self.datas.data_write_for_commit.remove(data_id);
564            }
565            let tidy = self.apply_le(&forget_ts).await;
566
567            Ok((registered, tidy))
568        })
569        .await
570    }
571
572    /// "Applies" all committed txns <= the given timestamp, ensuring that reads
573    /// at that timestamp will not block.
574    ///
575    /// In the common case, the txn committer will have done this work and this
576    /// method will be a no-op, but it is not guaranteed. In the event of a
577    /// crash or race, this does whatever persist writes are necessary (and
578    /// returns the resulting maintenance work), which could be significant.
579    ///
580    /// If the requested timestamp has not yet been written, this could block
581    /// for an unbounded amount of time.
582    ///
583    /// This method is idempotent.
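    ///
    /// For example (sketch; `read_ts` is whatever timestamp the caller is
    /// about to read data shards at):
    ///
    /// ```text
    /// // Make sure reads at `read_ts` won't block on un-applied txns...
    /// let tidy = txns.apply_le(&read_ts).await;
    /// // ...and fold the resulting retraction work into a later txn so the
    /// // txns shard doesn't accumulate garbage.
    /// let mut txn = txns.begin();
    /// txn.tidy(tidy);
    /// ```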
584    #[instrument(level = "debug", fields(ts = ?ts))]
585    pub async fn apply_le(&mut self, ts: &T) -> Tidy {
586        let op = &self.metrics.apply_le;
587        op.run(async {
588            debug!("apply_le {:?}", ts);
589            let _ = self.txns_cache.update_gt(ts).await;
590            self.txns_cache.update_gauges(&self.metrics);
591
592            let mut unapplied_by_data = BTreeMap::<_, Vec<_>>::new();
593            for (data_id, unapplied, unapplied_ts) in self.txns_cache.unapplied() {
594                if ts < unapplied_ts {
595                    break;
596                }
597                unapplied_by_data
598                    .entry(*data_id)
599                    .or_default()
600                    .push((unapplied, unapplied_ts));
601            }
602
603            let retractions = FuturesUnordered::new();
604            for (data_id, unapplied) in unapplied_by_data {
605                let mut data_write = self.datas.take_write_for_apply(&data_id).await;
606                retractions.push(async move {
607                    let mut ret = Vec::new();
608                    for (unapplied, unapplied_ts) in unapplied {
609                        match unapplied {
610                            Unapplied::RegisterForget => {
611                                let () = crate::empty_caa(
612                                    || {
613                                        format!(
614                                            "data {:.9} register/forget fill",
615                                            data_id.to_string()
616                                        )
617                                    },
618                                    &mut data_write,
619                                    unapplied_ts.clone(),
620                                )
621                                .await;
622                            }
623                            Unapplied::Batch(batch_raws) => {
624                                let batch_raws = batch_raws
625                                    .into_iter()
626                                    .map(|batch_raw| batch_raw.as_slice())
627                                    .collect();
628                                crate::apply_caa(
629                                    &mut data_write,
630                                    &batch_raws,
631                                    unapplied_ts.clone(),
632                                )
633                                .await;
634                                for batch_raw in batch_raws {
635                                    // NB: Protos are not guaranteed to exactly roundtrip the
636                                    // encoded bytes, so we intentionally use the raw batch so that
637                                    // it definitely retracts.
638                                    ret.push((
639                                        batch_raw.to_vec(),
640                                        (T::encode(unapplied_ts), data_id),
641                                    ));
642                                }
643                            }
644                        }
645                    }
646                    (data_write, ret)
647                });
648            }
649            let retractions = retractions.collect::<Vec<_>>().await;
650            let retractions = retractions
651                .into_iter()
652                .flat_map(|(data_write, retractions)| {
653                    self.datas.put_write_for_apply(data_write);
654                    retractions
655                })
656                .collect();
657
658            // Remove all the applied registers.
659            self.txns_cache.mark_register_applied(ts);
660
661            debug!("apply_le {:?} success", ts);
662            Tidy { retractions }
663        })
664        .await
665    }
666
667    /// Commits the tidy work at the given time.
668    ///
669    /// Mostly a helper to make it obvious that we can throw away the apply work
670    /// (and not get into an infinite cycle of tidy->apply->tidy).
671    #[cfg(test)]
672    pub async fn tidy_at(&mut self, tidy_ts: T, tidy: Tidy) -> Result<(), T> {
673        debug!("tidy at {:?}", tidy_ts);
674
675        let mut txn = self.begin();
676        txn.tidy(tidy);
677        // We just constructed this txn, so it couldn't have committed any
678        // batches, and thus there's nothing to apply. We're free to throw it
679        // away.
680        let apply = txn.commit_at(self, tidy_ts.clone()).await?;
681        assert!(apply.is_empty());
682
683        debug!("tidy at {:?} success", tidy_ts);
684        Ok(())
685    }
686
    /// Allows compaction of the txns shard, as well as of internal
    /// representations, losing the ability to answer queries about times
    /// less_than since_ts.
689    ///
690    /// In practice, this will likely only be called from the singleton
691    /// controller process.
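    ///
    /// For example (sketch):
    ///
    /// ```text
    /// // Note that this may compact to something less than `since_ts` if
    /// // there are txns <= since_ts that have not yet been applied.
    /// txns.compact_to(since_ts).await;
    /// ```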
692    pub async fn compact_to(&mut self, mut since_ts: T) {
693        let op = &self.metrics.compact_to;
694        op.run(async {
695            tracing::debug!("compact_to {:?}", since_ts);
696            let _ = self.txns_cache.update_gt(&since_ts).await;
697
698            // NB: A critical invariant for how this all works is that we never
699            // allow the since of the txns shard to pass any unapplied writes, so
700            // reduce it as necessary.
701            let min_unapplied_ts = self.txns_cache.min_unapplied_ts();
702            if min_unapplied_ts < &since_ts {
703                since_ts.clone_from(min_unapplied_ts);
704            }
705            crate::cads::<T, O, C>(&mut self.txns_since, since_ts).await;
706        })
707        .await
708    }
709
710    /// Upgrade the version on the backing shard.
711    ///
712    /// In practice, this will likely only be called from the singleton
713    /// controller process.
714    pub async fn upgrade_version(&mut self) {
715        self.txns_since
716            .upgrade_version()
717            .await
718            .expect("invalid usage")
719    }
720
721    /// Returns the [ShardId] of the txns shard.
722    pub fn txns_id(&self) -> ShardId {
723        self.txns_write.shard_id()
724    }
725
726    /// Returns the [TxnsCache] used by this handle.
727    pub fn read_cache(&self) -> &TxnsCache<T, C> {
728        &self.txns_cache
729    }
730}
731
732/// A token representing maintenance writes (in particular, retractions) to the
733/// txns shard.
734///
735/// This can be written on its own with `TxnsHandle::tidy_at` or sidecar'd into
736/// a normal txn with [Txn::tidy].
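///
/// Work from multiple applies can be merged and handed off together, e.g.
/// (sketch):
///
/// ```text
/// let mut tidy = txns.apply_le(&ts0).await;
/// tidy.merge(txns.apply_le(&ts1).await);
/// let mut txn = txns.begin();
/// txn.tidy(tidy);
/// // Committing `txn` also retracts the applied batches from the txns shard.
/// ```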
737#[derive(Debug, Default)]
738pub struct Tidy {
739    pub(crate) retractions: BTreeMap<Vec<u8>, ([u8; 8], ShardId)>,
740}
741
742impl Tidy {
743    /// Merges the work represented by the other tidy into this one.
744    pub fn merge(&mut self, other: Tidy) {
745        self.retractions.extend(other.retractions)
746    }
747}
748
749/// A helper to make a more targeted mutable borrow of self.
750#[derive(Debug)]
751pub(crate) struct DataHandles<K: Codec, V: Codec, T, D> {
752    pub(crate) dyncfgs: ConfigSet,
753    pub(crate) client: Arc<PersistClient>,
754    /// See [DataWriteApply].
755    ///
756    /// This is lazily populated with the set of shards touched by `apply_le`.
757    data_write_for_apply: BTreeMap<ShardId, DataWriteApply<K, V, T, D>>,
758    /// See [DataWriteCommit].
759    ///
    /// This contains the set of data shards registered (and not yet forgotten)
    /// via this particular [TxnsHandle].
762    ///
763    /// NB: In the common case, this and `_for_apply` will contain the same set
764    /// of shards, but this is not required. A shard can be in either and not
765    /// the other.
766    data_write_for_commit: BTreeMap<ShardId, DataWriteCommit<K, V, T, D>>,
767}
768
769impl<K, V, T, D> DataHandles<K, V, T, D>
770where
771    K: Debug + Codec,
772    V: Debug + Codec,
773    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
774    D: Monoid + Ord + Codec64 + Send + Sync,
775{
776    async fn open_data_write_for_apply(&self, data_id: ShardId) -> DataWriteApply<K, V, T, D> {
777        let diagnostics = Diagnostics::from_purpose("txn data");
778        let schemas = self
779            .client
780            .latest_schema::<K, V, T, D>(data_id, diagnostics.clone())
781            .await
782            .expect("codecs have not changed");
783        let (key_schema, val_schema) = match schemas {
784            Some((_, key_schema, val_schema)) => (Arc::new(key_schema), Arc::new(val_schema)),
785            // We will always have at least one schema registered by the time we reach this point,
786            // because that is ensured at txn-registration time.
787            None => unreachable!("data shard {} should have a schema", data_id),
788        };
789        let wrapped = self
790            .client
791            .open_writer(data_id, key_schema, val_schema, diagnostics)
792            .await
793            .expect("schema shouldn't change");
794        DataWriteApply {
795            apply_ensure_schema_match: APPLY_ENSURE_SCHEMA_MATCH.handle(&self.dyncfgs),
796            client: Arc::clone(&self.client),
797            wrapped,
798        }
799    }
800
801    pub(crate) async fn take_write_for_apply(
802        &mut self,
803        data_id: &ShardId,
804    ) -> DataWriteApply<K, V, T, D> {
805        if let Some(data_write) = self.data_write_for_apply.remove(data_id) {
806            return data_write;
807        }
808        self.open_data_write_for_apply(*data_id).await
809    }
810
811    pub(crate) fn put_write_for_apply(&mut self, data_write: DataWriteApply<K, V, T, D>) {
812        self.data_write_for_apply
813            .insert(data_write.shard_id(), data_write);
814    }
815
816    pub(crate) fn take_write_for_commit(
817        &mut self,
818        data_id: &ShardId,
819    ) -> Option<DataWriteCommit<K, V, T, D>> {
820        self.data_write_for_commit.remove(data_id)
821    }
822
823    pub(crate) fn put_write_for_commit(&mut self, data_write: DataWriteCommit<K, V, T, D>) {
824        let prev = self
825            .data_write_for_commit
826            .insert(data_write.shard_id(), data_write);
827        assert!(prev.is_none());
828    }
829}
830
831/// A newtype wrapper around [WriteHandle] indicating that it has a real schema
832/// registered by the user.
833///
834/// The txn-wal user declares which schema they'd like to use for committing
835/// batches by passing it in (as part of the WriteHandle) in the call to
/// register. This must be used to encode any new batches written. The wrapper
/// keeps us from accidentally mixing it up with the WriteHandles that we
/// internally invent for applying the batches (which use a schema matching the
/// one declared in the batch).
840#[derive(Debug)]
841pub(crate) struct DataWriteCommit<K: Codec, V: Codec, T, D>(pub(crate) WriteHandle<K, V, T, D>);
842
843impl<K: Codec, V: Codec, T, D> Deref for DataWriteCommit<K, V, T, D> {
844    type Target = WriteHandle<K, V, T, D>;
845
846    fn deref(&self) -> &Self::Target {
847        &self.0
848    }
849}
850
851impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteCommit<K, V, T, D> {
852    fn deref_mut(&mut self) -> &mut Self::Target {
853        &mut self.0
854    }
855}
856
857/// A newtype wrapper around [WriteHandle] indicating that it can alter the
/// schema it's using to match the one in the batches being appended.
859///
860/// When a batch is committed to txn-wal, it contains metadata about which
861/// schemas were used to encode the data in it. Txn-wal then uses this info to
/// make sure that, in [TxnsHandle::apply_le], the `compare_and_append` call
863/// happens on a handle with the same schema. This is accomplished by querying
864/// the persist schema registry.
865#[derive(Debug)]
866pub(crate) struct DataWriteApply<K: Codec, V: Codec, T, D> {
867    client: Arc<PersistClient>,
868    apply_ensure_schema_match: ConfigValHandle<bool>,
869    pub(crate) wrapped: WriteHandle<K, V, T, D>,
870}
871
872impl<K: Codec, V: Codec, T, D> Deref for DataWriteApply<K, V, T, D> {
873    type Target = WriteHandle<K, V, T, D>;
874
875    fn deref(&self) -> &Self::Target {
876        &self.wrapped
877    }
878}
879
880impl<K: Codec, V: Codec, T, D> DerefMut for DataWriteApply<K, V, T, D> {
881    fn deref_mut(&mut self) -> &mut Self::Target {
882        &mut self.wrapped
883    }
884}
885
886pub(crate) const APPLY_ENSURE_SCHEMA_MATCH: Config<bool> = Config::new(
887    "txn_wal_apply_ensure_schema_match",
888    true,
889    "CYA to skip updating write handle to batch schema in apply",
890);
891
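/// Returns `Ok` with the single [SchemaId] shared by every element of the
/// iterator (or `Ok(None)` if the iterator is empty), otherwise an `Err`
/// containing a mismatched pair of ids.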
892fn at_most_one_schema(
893    schemas: impl Iterator<Item = SchemaId>,
894) -> Result<Option<SchemaId>, (SchemaId, SchemaId)> {
895    let mut schema = None;
896    for s in schemas {
897        match schema {
898            None => schema = Some(s),
899            Some(x) if s != x => return Err((s, x)),
900            Some(_) => continue,
901        }
902    }
903    Ok(schema)
904}
905
906impl<K, V, T, D> DataWriteApply<K, V, T, D>
907where
908    K: Debug + Codec,
909    V: Debug + Codec,
910    T: Timestamp + Lattice + TotalOrder + Codec64 + Sync,
911    D: Monoid + Ord + Codec64 + Send + Sync,
912{
913    pub(crate) async fn maybe_replace_with_batch_schema(&mut self, batches: &[Batch<K, V, T, D>]) {
914        // TODO: Remove this once everything is rolled out and we're sure it's
915        // not going to cause any issues.
916        if !self.apply_ensure_schema_match.get() {
917            return;
918        }
919        let batch_schema = at_most_one_schema(batches.iter().flat_map(|x| x.schemas()));
920        let batch_schema = batch_schema.unwrap_or_else(|_| {
921            panic!(
922                "txn-wal uses at most one schema to commit batches, got: {:?}",
923                batches.iter().flat_map(|x| x.schemas()).collect::<Vec<_>>()
924            )
925        });
926        let (batch_schema, handle_schema) = match (batch_schema, self.wrapped.schema_id()) {
927            (Some(batch_schema), Some(handle_schema)) if batch_schema != handle_schema => {
928                (batch_schema, handle_schema)
929            }
930            _ => return,
931        };
932
933        let data_id = self.shard_id();
934        let diagnostics = Diagnostics::from_purpose("txn data");
935        let (key_schema, val_schema) = self
936            .client
937            .get_schema::<K, V, T, D>(data_id, batch_schema, diagnostics.clone())
938            .await
939            .expect("codecs shouldn't change")
940            .expect("id must have been registered to create this batch");
941        let new_data_write = self
942            .client
943            .open_writer(
944                self.shard_id(),
945                Arc::new(key_schema),
946                Arc::new(val_schema),
947                diagnostics,
948            )
949            .await
950            .expect("codecs shouldn't change");
951        tracing::info!(
952            "updated {} write handle from {} to {} to apply batches",
953            data_id,
954            handle_schema,
955            batch_schema
956        );
957        assert_eq!(new_data_write.schema_id(), Some(batch_schema));
958        self.wrapped = new_data_write;
959    }
960}
961
962#[cfg(test)]
963mod tests {
964    use std::time::{Duration, UNIX_EPOCH};
965
966    use differential_dataflow::Hashable;
967    use futures::future::BoxFuture;
968    use mz_ore::assert_none;
969    use mz_ore::cast::CastFrom;
970    use mz_ore::collections::CollectionExt;
971    use mz_ore::metrics::MetricsRegistry;
972    use mz_persist_client::PersistLocation;
973    use mz_persist_client::cache::PersistClientCache;
974    use mz_persist_client::cfg::RetryParameters;
975    use rand::rngs::SmallRng;
976    use rand::{RngCore, SeedableRng};
977    use timely::progress::Antichain;
978    use tokio::sync::oneshot;
979    use tracing::{Instrument, info, info_span};
980
981    use crate::operator::DataSubscribe;
982    use crate::tests::{CommitLog, reader, write_directly, writer};
983
984    use super::*;
985
986    impl TxnsHandle<String, (), u64, i64, u64, TxnsCodecDefault> {
987        pub(crate) async fn expect_open(client: PersistClient) -> Self {
988            Self::expect_open_id(client, ShardId::new()).await
989        }
990
991        pub(crate) async fn expect_open_id(client: PersistClient, txns_id: ShardId) -> Self {
992            let dyncfgs = crate::all_dyncfgs(client.dyncfgs().clone());
993            Self::open(
994                0,
995                client,
996                dyncfgs,
997                Arc::new(Metrics::new(&MetricsRegistry::new())),
998                txns_id,
999            )
1000            .await
1001        }
1002
1003        pub(crate) fn new_log(&self) -> CommitLog {
1004            CommitLog::new((*self.datas.client).clone(), self.txns_id())
1005        }
1006
1007        pub(crate) async fn expect_register(&mut self, register_ts: u64) -> ShardId {
1008            self.expect_registers(register_ts, 1).await.into_element()
1009        }
1010
1011        pub(crate) async fn expect_registers(
1012            &mut self,
1013            register_ts: u64,
1014            amount: usize,
1015        ) -> Vec<ShardId> {
1016            let data_ids: Vec<_> = (0..amount).map(|_| ShardId::new()).collect();
1017            let mut writers = Vec::new();
1018            for data_id in &data_ids {
1019                writers.push(writer(&self.datas.client, *data_id).await);
1020            }
1021            self.register(register_ts, writers).await.unwrap();
1022            data_ids
1023        }
1024
1025        pub(crate) async fn expect_commit_at(
1026            &mut self,
1027            commit_ts: u64,
1028            data_id: ShardId,
1029            keys: &[&str],
1030            log: &CommitLog,
1031        ) -> Tidy {
1032            let mut txn = self.begin();
1033            for key in keys {
1034                txn.write(&data_id, (*key).into(), (), 1).await;
1035            }
1036            let tidy = txn
1037                .commit_at(self, commit_ts)
1038                .await
1039                .unwrap()
1040                .apply(self)
1041                .await;
1042            for key in keys {
1043                log.record((data_id, (*key).into(), commit_ts, 1));
1044            }
1045            tidy
1046        }
1047    }
1048
1049    #[mz_ore::test(tokio::test)]
1050    #[cfg_attr(miri, ignore)] // too slow
1051    async fn register_at() {
1052        let client = PersistClient::new_for_tests().await;
1053        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1054        let log = txns.new_log();
1055        let d0 = txns.expect_register(2).await;
1056
        // Registering a second time is a no-op (idempotent).
1058        txns.register(3, [writer(&client, d0).await]).await.unwrap();
1059
1060        // Cannot register a new data shard at an already closed off time. An
1061        // error is returned with the first time that a registration would
1062        // succeed.
1063        let d1 = ShardId::new();
1064        assert_eq!(
1065            txns.register(2, [writer(&client, d1).await])
1066                .await
1067                .unwrap_err(),
1068            4
1069        );
1070
1071        // Can still register after txns have been committed.
1072        txns.expect_commit_at(4, d0, &["foo"], &log).await;
1073        txns.register(5, [writer(&client, d1).await]).await.unwrap();
1074
1075        // We can also register some new and some already registered shards.
1076        let d2 = ShardId::new();
1077        txns.register(6, [writer(&client, d0).await, writer(&client, d2).await])
1078            .await
1079            .unwrap();
1080
1081        let () = log.assert_snapshot(d0, 6).await;
1082        let () = log.assert_snapshot(d1, 6).await;
1083    }
1084
1085    /// A sanity check that CommitLog catches an incorrect usage (proxy for a
1086    /// bug that looks like an incorrect usage).
1087    #[mz_ore::test(tokio::test)]
1088    #[cfg_attr(miri, ignore)] // too slow
1089    #[should_panic(expected = "left: [(\"foo\", 2, 1)]\n right: [(\"foo\", 2, 2)]")]
1090    async fn incorrect_usage_register_write_same_time() {
1091        let client = PersistClient::new_for_tests().await;
1092        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1093        let log = txns.new_log();
1094        let d0 = txns.expect_register(1).await;
1095        let mut d0_write = writer(&client, d0).await;
1096
1097        // Commit a write at ts 2...
1098        let mut txn = txns.begin_test();
1099        txn.write(&d0, "foo".into(), (), 1).await;
1100        let apply = txn.commit_at(&mut txns, 2).await.unwrap();
1101        log.record_txn(2, &txn);
1102        // ... and (incorrectly) also write to the shard normally at ts 2.
1103        let () = d0_write
1104            .compare_and_append(
1105                &[(("foo".to_owned(), ()), 2, 1)],
1106                Antichain::from_elem(2),
1107                Antichain::from_elem(3),
1108            )
1109            .await
1110            .unwrap()
1111            .unwrap();
1112        log.record((d0, "foo".into(), 2, 1));
1113        apply.apply(&mut txns).await;
1114
1115        // Verify that CommitLog catches this.
1116        log.assert_snapshot(d0, 2).await;
1117    }
1118
1119    #[mz_ore::test(tokio::test)]
1120    #[cfg_attr(miri, ignore)] // too slow
1121    async fn forget_at() {
1122        let client = PersistClient::new_for_tests().await;
1123        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1124        let log = txns.new_log();
1125
1126        // Can forget a data_shard that has not been registered.
1127        txns.forget(1, [ShardId::new()]).await.unwrap();
1128
1129        // Can forget multiple data_shards that have not been registered.
1130        txns.forget(2, (0..5).map(|_| ShardId::new()))
1131            .await
1132            .unwrap();
1133
1134        // Can forget a registered shard.
1135        let d0 = txns.expect_register(3).await;
1136        txns.forget(4, [d0]).await.unwrap();
1137
1138        // Can forget multiple registered shards.
1139        let ds = txns.expect_registers(5, 5).await;
1140        txns.forget(6, ds.clone()).await.unwrap();
1141
1142        // Forget is idempotent.
1143        txns.forget(7, [d0]).await.unwrap();
1144        txns.forget(8, ds.clone()).await.unwrap();
1145
1146        // Cannot forget at an already closed off time. An error is returned
        // with the first time at which a forget would succeed.
1148        let d1 = txns.expect_register(9).await;
1149        assert_eq!(txns.forget(9, [d1]).await.unwrap_err(), 10);
1150
1151        // Write to txns and to d0 directly.
1152        let mut d0_write = writer(&client, d0).await;
1153        txns.expect_commit_at(10, d1, &["d1"], &log).await;
1154        let updates = [(("d0".to_owned(), ()), 10, 1)];
1155        d0_write
1156            .compare_and_append(&updates, d0_write.shared_upper(), Antichain::from_elem(11))
1157            .await
1158            .unwrap()
1159            .unwrap();
1160        log.record((d0, "d0".into(), 10, 1));
1161
1162        // Can register and forget an already registered and forgotten shard.
1163        txns.register(11, [writer(&client, d0).await])
1164            .await
1165            .unwrap();
1166        let mut forget_expected = vec![d0, d1];
1167        forget_expected.sort();
1168        assert_eq!(txns.forget_all(12).await.unwrap().0, forget_expected);
1169
1170        // Close shard to writes
1171        d0_write
1172            .compare_and_append_batch(&mut [], d0_write.shared_upper(), Antichain::new(), true)
1173            .await
1174            .unwrap()
1175            .unwrap();
1176
1177        let () = log.assert_snapshot(d0, 12).await;
1178        let () = log.assert_snapshot(d1, 12).await;
1179
1180        for di in ds {
1181            let mut di_write = writer(&client, di).await;
1182
1183            // Close shards to writes
1184            di_write
1185                .compare_and_append_batch(&mut [], di_write.shared_upper(), Antichain::new(), true)
1186                .await
1187                .unwrap()
1188                .unwrap();
1189
1190            let () = log.assert_snapshot(di, 8).await;
1191        }
1192    }
1193
1194    #[mz_ore::test(tokio::test)]
1195    #[cfg_attr(miri, ignore)] // too slow
1196    async fn register_forget() {
1197        async fn step_some_past(subs: &mut Vec<DataSubscribe>, ts: u64) {
1198            for (idx, sub) in subs.iter_mut().enumerate() {
1199                // Only step some of them to try to maximize edge cases.
1200                if usize::cast_from(ts) % (idx + 1) == 0 {
1201                    async {
1202                        info!("stepping sub {} past {}", idx, ts);
1203                        sub.step_past(ts).await;
1204                    }
1205                    .instrument(info_span!("sub", idx))
1206                    .await;
1207                }
1208            }
1209        }
1210
1211        let client = PersistClient::new_for_tests().await;
1212        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1213        let log = txns.new_log();
1214        let d0 = ShardId::new();
1215        let mut d0_write = writer(&client, d0).await;
1216        let mut subs = Vec::new();
1217
1218        // Loop for a while doing the following:
1219        // - Write directly to some time before register ts
1220        // - Register
1221        // - Write via txns
1222        // - Forget
1223        //
1224        // After each step, make sure that a subscription started at that time
1225        // works and that all subscriptions can be stepped through the expected
1226        // timestamp.
1227        let mut ts = 0;
1228        while ts < 32 {
1229            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1230            ts += 1;
1231            info!("{} direct", ts);
1232            txns.begin().commit_at(&mut txns, ts).await.unwrap();
1233            write_directly(ts, &mut d0_write, &[&format!("d{}", ts)], &log).await;
1234            step_some_past(&mut subs, ts).await;
1235            if ts % 11 == 0 {
1236                txns.compact_to(ts).await;
1237            }
1238
1239            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1240            ts += 1;
1241            info!("{} register", ts);
1242            txns.register(ts, [writer(&client, d0).await])
1243                .await
1244                .unwrap();
1245            step_some_past(&mut subs, ts).await;
1246            if ts % 11 == 0 {
1247                txns.compact_to(ts).await;
1248            }
1249
1250            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1251            ts += 1;
1252            info!("{} txns", ts);
1253            txns.expect_commit_at(ts, d0, &[&format!("t{}", ts)], &log)
1254                .await;
1255            step_some_past(&mut subs, ts).await;
1256            if ts % 11 == 0 {
1257                txns.compact_to(ts).await;
1258            }
1259
1260            subs.push(txns.read_cache().expect_subscribe(&client, d0, ts));
1261            ts += 1;
1262            info!("{} forget", ts);
1263            txns.forget(ts, [d0]).await.unwrap();
1264            step_some_past(&mut subs, ts).await;
1265            if ts % 11 == 0 {
1266                txns.compact_to(ts).await;
1267            }
1268        }
1269
1270        // Check all the subscribes.
1271        for mut sub in subs.into_iter() {
1272            sub.step_past(ts).await;
1273            log.assert_eq(d0, sub.as_of, sub.progress(), sub.output().clone());
1274        }
1275    }
1276
1277    // Regression test for a bug encountered during initial development:
1278    // - task 0 commits to a data shard at ts T
1279    // - before task 0 can unblock T for reads, task 1 tries to register it at
1280    //   time T (ditto T+X), this does a caa of empty space, advancing the upper
1281    //   of the data shard to T+1
1282    // - task 0 attempts to caa in the batch, but finds out the upper is T+1 and
1283    //   assumes someone else did the work
1284    // - result: the write is lost
1285    #[mz_ore::test(tokio::test)]
1286    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1287    async fn race_data_shard_register_and_commit() {
1288        let client = PersistClient::new_for_tests().await;
1289        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1290        let d0 = txns.expect_register(1).await;
1291
1292        let mut txn = txns.begin();
1293        txn.write(&d0, "foo".into(), (), 1).await;
1294        let commit_apply = txn.commit_at(&mut txns, 2).await.unwrap();
1295
1296        txns.register(3, [writer(&client, d0).await]).await.unwrap();
1297
1298        // Make sure that we can read empty at the register commit time even
1299        // before the txn commit apply.
1300        let actual = txns.txns_cache.expect_snapshot(&client, d0, 1).await;
1301        assert_eq!(actual, Vec::<String>::new());
1302
1303        commit_apply.apply(&mut txns).await;
1304        let actual = txns.txns_cache.expect_snapshot(&client, d0, 2).await;
1305        assert_eq!(actual, vec!["foo".to_owned()]);
1306    }
1307
1308    // A test that applies a batch of writes all at once.
1309    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1310    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1311    async fn apply_many_ts() {
1312        let client = PersistClient::new_for_tests().await;
1313        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1314        let log = txns.new_log();
1315        let d0 = txns.expect_register(1).await;
1316
1317        for ts in 2..10 {
1318            let mut txn = txns.begin();
1319            txn.write(&d0, ts.to_string(), (), 1).await;
1320            let _apply = txn.commit_at(&mut txns, ts).await.unwrap();
1321            log.record((d0, ts.to_string(), ts, 1));
1322        }
1323        // This automatically runs the apply, which catches up all the previous
1324        // txns at once.
1325        txns.expect_commit_at(10, d0, &[], &log).await;
1326
1327        log.assert_snapshot(d0, 10).await;
1328    }
1329
1330    struct StressWorker {
1331        idx: usize,
1332        data_ids: Vec<ShardId>,
1333        txns: TxnsHandle<String, (), u64, i64>,
1334        log: CommitLog,
1335        tidy: Tidy,
1336        ts: u64,
1337        step: usize,
1338        rng: SmallRng,
1339        reads: Vec<(
1340            oneshot::Sender<u64>,
1341            ShardId,
1342            u64,
1343            mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
1344        )>,
1345    }
1346
1347    impl StressWorker {
1348        pub async fn step(&mut self) {
1349            debug!(
1350                "stress {} step {} START ts={}",
1351                self.idx, self.step, self.ts
1352            );
1353            let data_id =
1354                self.data_ids[usize::cast_from(self.rng.next_u64()) % self.data_ids.len()];
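            // Pick one of six actions uniformly at random; reads occupy two of
            // the six arms, so they are twice as likely as any other action.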
1355            match self.rng.next_u64() % 6 {
1356                0 => self.write(data_id).await,
1357                // The register and forget impls intentionally don't check
1358                // whether the shard is already registered, to stress idempotence.
1359                1 => self.register(data_id).await,
1360                2 => self.forget(data_id).await,
1361                3 => {
1362                    debug!("stress update {:.9} to {}", data_id.to_string(), self.ts);
1363                    let _ = self.txns.txns_cache.update_ge(&self.ts).await;
1364                }
1365                4 => self.start_read(data_id),
1366                5 => self.start_read(data_id),
1367                _ => unreachable!(),
1368            }
1369            debug!("stress {} step {} DONE ts={}", self.idx, self.step, self.ts);
1370            self.step += 1;
1371        }
1372
1373        fn key(&self) -> String {
1374            format!("w{}s{}", self.idx, self.step)
1375        }
1376
1377        async fn registered_at_progress_ts(&mut self, data_id: ShardId) -> bool {
1378            self.ts = *self.txns.txns_cache.update_ge(&self.ts).await;
1379            self.txns
1380                .txns_cache
1381                .registered_at_progress(&data_id, &self.ts)
1382        }
1383
1384        // Writes to the given data shard, either via txns if it's registered or
1385        // directly if it's not.
1386        async fn write(&mut self, data_id: ShardId) {
1387            // Make sure to keep the registered_at_progress_ts call _inside_ the
1388            // retry loop, because a data shard might switch between registered and
1389            // unregistered as the loop advances through timestamps.
1390            self.retry_ts_err(&mut |w: &mut StressWorker| {
1391                Box::pin(async move {
1392                    if w.registered_at_progress_ts(data_id).await {
1393                        w.write_via_txns(data_id).await
1394                    } else {
1395                        w.write_direct(data_id).await
1396                    }
1397                })
1398            })
1399            .await
1400        }
1401
1402        async fn write_via_txns(&mut self, data_id: ShardId) -> Result<(), u64> {
1403            debug!(
1404                "stress write_via_txns {:.9} at {}",
1405                data_id.to_string(),
1406                self.ts
1407            );
1408            // HACK: Normally, we'd make sure that this particular handle had
1409            // registered the data shard before writing to it, but that would
1410            // consume a ts and isn't quite how we want `write_via_txns` to
1411            // work. Work around that by setting a write handle (with a schema
1412            // that we promise is correct) in the right place.
1413            if !self.txns.datas.data_write_for_commit.contains_key(&data_id) {
1414                let x = writer(&self.txns.datas.client, data_id).await;
1415                self.txns
1416                    .datas
1417                    .data_write_for_commit
1418                    .insert(data_id, DataWriteCommit(x));
1419            }
1420            let mut txn = self.txns.begin_test();
1421            txn.tidy(std::mem::take(&mut self.tidy));
1422            txn.write(&data_id, self.key(), (), 1).await;
1423            let apply = txn.commit_at(&mut self.txns, self.ts).await?;
1424            debug!(
1425                "log {:.9} {} at {}",
1426                data_id.to_string(),
1427                self.key(),
1428                self.ts
1429            );
1430            self.log.record_txn(self.ts, &txn);
1431            if self.rng.next_u64() % 3 == 0 {
1432                self.tidy.merge(apply.apply(&mut self.txns).await);
1433            }
1434            Ok(())
1435        }
1436
1437        async fn write_direct(&mut self, data_id: ShardId) -> Result<(), u64> {
1438            debug!(
1439                "stress write_direct {:.9} at {}",
1440                data_id.to_string(),
1441                self.ts
1442            );
1443            // First write an empty txn to ensure that the shard isn't
1444            // registered at this ts by someone else.
1445            self.txns.begin().commit_at(&mut self.txns, self.ts).await?;
1446
1447            let mut write = writer(&self.txns.datas.client, data_id).await;
1448            let mut current = write.shared_upper().into_option().unwrap();
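            // Retry the caa, chasing the data shard's upper, until the update at
            // self.ts lands or the upper has already advanced past self.ts (in
            // which case the Err makes the caller retry at a later ts).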
1449            loop {
1450                if !(current <= self.ts) {
1451                    return Err(current);
1452                }
1453                let key = self.key();
1454                let updates = [((&key, &()), &self.ts, 1)];
1455                let res = crate::small_caa(
1456                    || format!("data {:.9} direct", data_id.to_string()),
1457                    &mut write,
1458                    &updates,
1459                    current,
1460                    self.ts + 1,
1461                )
1462                .await;
1463                match res {
1464                    Ok(()) => {
1465                        debug!("log {:.9} {} at {}", data_id.to_string(), key, self.ts);
1466                        self.log.record((data_id, key, self.ts, 1));
1467                        return Ok(());
1468                    }
1469                    Err(new_current) => current = new_current,
1470                }
1471            }
1472        }
1473
1474        async fn register(&mut self, data_id: ShardId) {
1475            self.retry_ts_err(&mut |w: &mut StressWorker| {
1476                debug!("stress register {:.9} at {}", data_id.to_string(), w.ts);
1477                Box::pin(async move {
1478                    let data_write = writer(&w.txns.datas.client, data_id).await;
1479                    let _ = w.txns.register(w.ts, [data_write]).await?;
1480                    Ok(())
1481                })
1482            })
1483            .await
1484        }
1485
1486        async fn forget(&mut self, data_id: ShardId) {
1487            self.retry_ts_err(&mut |w: &mut StressWorker| {
1488                debug!("stress forget {:.9} at {}", data_id.to_string(), w.ts);
1489                Box::pin(async move { w.txns.forget(w.ts, [data_id]).await.map(|_| ()) })
1490            })
1491            .await
1492        }
1493
1494        fn start_read(&mut self, data_id: ShardId) {
1495            debug!(
1496                "stress start_read {:.9} at {}",
1497                data_id.to_string(),
1498                self.ts
1499            );
1500            let client = (*self.txns.datas.client).clone();
1501            let txns_id = self.txns.txns_id();
1502            let as_of = self.ts;
1503            debug!("start_read {:.9} as_of {}", data_id.to_string(), as_of);
1504            let (tx, mut rx) = oneshot::channel();
1505            let subscribe = mz_ore::task::spawn_blocking(
1506                || format!("{:.9}-{}", data_id.to_string(), as_of),
1507                move || {
1508                    let mut subscribe = DataSubscribe::new(
1509                        "test",
1510                        client,
1511                        txns_id,
1512                        data_id,
1513                        as_of,
1514                        Antichain::new(),
1515                    );
1516                    let data_id = format!("{:.9}", data_id.to_string());
1517                    let _guard = info_span!("read_worker", %data_id, as_of).entered();
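                    // Step the dataflow until the main task sends the final read
                    // ts over the oneshot channel, then keep stepping until the
                    // subscribe has progressed past it and return the output.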
1518                    loop {
1519                        subscribe.worker.step_or_park(None);
1520                        subscribe.capture_output();
1521                        let until = match rx.try_recv() {
1522                            Ok(ts) => ts,
1523                            Err(oneshot::error::TryRecvError::Empty) => {
1524                                continue;
1525                            }
1526                            Err(oneshot::error::TryRecvError::Closed) => 0,
1527                        };
1528                        while subscribe.progress() < until {
1529                            subscribe.worker.step_or_park(None);
1530                            subscribe.capture_output();
1531                        }
1532                        return subscribe.output().clone();
1533                    }
1534                },
1535            );
1536            self.reads.push((tx, data_id, as_of, subscribe));
1537        }
1538
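        // Runs `work_fn` in a loop. Whenever it returns Err(ts), advance self.ts
        // to that ts and try again.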
1539        async fn retry_ts_err<W>(&mut self, work_fn: &mut W)
1540        where
1541            W: for<'b> FnMut(&'b mut Self) -> BoxFuture<'b, Result<(), u64>>,
1542        {
1543            loop {
1544                match work_fn(self).await {
1545                    Ok(ret) => return ret,
1546                    Err(new_ts) => self.ts = new_ts,
1547                }
1548            }
1549        }
1550    }
1551
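    // Randomized correctness test: NUM_WORKERS workers concurrently write to,
    // register, forget, and read from NUM_DATA_SHARDS data shards, then every
    // shard's contents are checked against the CommitLog.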
1552    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1553    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1554    async fn stress_correctness() {
1555        const NUM_DATA_SHARDS: usize = 2;
1556        const NUM_WORKERS: usize = 2;
1557        const NUM_STEPS_PER_WORKER: usize = 100;
1558        let seed = UNIX_EPOCH.elapsed().unwrap().hashed();
1559        eprintln!("using seed {}", seed);
1560
1561        let mut clients = PersistClientCache::new_no_metrics();
1562        // We disable pubsub below, so retune the listen retries (pubsub
1563        // fallback) to keep the test speedy.
1564        clients
1565            .cfg()
1566            .set_next_listen_batch_retryer(RetryParameters {
1567                fixed_sleep: Duration::ZERO,
1568                initial_backoff: Duration::from_millis(1),
1569                multiplier: 1,
1570                clamp: Duration::from_millis(1),
1571            });
1572        let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1573        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1574        let log = txns.new_log();
1575        let data_ids = (0..NUM_DATA_SHARDS)
1576            .map(|_| ShardId::new())
1577            .collect::<Vec<_>>();
1578        let data_writes = data_ids
1579            .iter()
1580            .map(|data_id| writer(&client, *data_id))
1581            .collect::<FuturesUnordered<_>>()
1582            .collect::<Vec<_>>()
1583            .await;
1584        let _data_sinces = data_ids
1585            .iter()
1586            .map(|data_id| reader(&client, *data_id))
1587            .collect::<FuturesUnordered<_>>()
1588            .collect::<Vec<_>>()
1589            .await;
1590        let register_ts = 1;
1591        txns.register(register_ts, data_writes).await.unwrap();
1592
1593        let mut workers = Vec::new();
1594        for idx in 0..NUM_WORKERS {
1595            // Clear the state cache before opening each client so the clients
1596            // are as disconnected from each other as possible.
1597            clients.clear_state_cache();
1598            let client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1599            let mut worker = StressWorker {
1600                idx,
1601                log: log.clone(),
1602                txns: TxnsHandle::expect_open_id(client.clone(), txns.txns_id()).await,
1603                data_ids: data_ids.clone(),
1604                tidy: Tidy::default(),
1605                ts: register_ts,
1606                step: 0,
1607                rng: SmallRng::seed_from_u64(seed.wrapping_add(u64::cast_from(idx))),
1608                reads: Vec::new(),
1609            };
1610            let worker = async move {
1611                while worker.step < NUM_STEPS_PER_WORKER {
1612                    worker.step().await;
1613                }
1614                (worker.ts, worker.reads)
1615            }
1616            .instrument(info_span!("stress_worker", idx));
1617            workers.push(mz_ore::task::spawn(|| format!("worker-{}", idx), worker));
1618        }
1619
1620        let mut max_ts = 0;
1621        let mut reads = Vec::new();
1622        for worker in workers {
1623            let (t, mut r) = worker.await;
1624            max_ts = std::cmp::max(max_ts, t);
1625            reads.append(&mut r);
1626        }
1627
1628        // Run all of the following in a timeout to make hangs easier to debug.
1629        tokio::time::timeout(Duration::from_secs(30), async {
1630            info!("finished with max_ts of {}", max_ts);
1631            txns.apply_le(&max_ts).await;
1632            for data_id in data_ids {
1633                info!("reading data shard {}", data_id);
1634                log.assert_snapshot(data_id, max_ts)
1635                    .instrument(info_span!("read_data", data_id = format!("{:.9}", data_id)))
1636                    .await;
1637            }
1638            info!("now waiting for reads {}", max_ts);
1639            for (tx, data_id, as_of, subscribe) in reads {
1640                let _ = tx.send(max_ts + 1);
1641                let output = subscribe.await;
1642                log.assert_eq(data_id, as_of, max_ts + 1, output);
1643            }
1644        })
1645        .await
1646        .unwrap();
1647    }
1648
1649    #[mz_ore::test(tokio::test)]
1650    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
1651    async fn advance_physical_uppers_past() {
1652        let client = PersistClient::new_for_tests().await;
1653        let mut txns = TxnsHandle::expect_open(client.clone()).await;
1654        let log = txns.new_log();
1655        let d0 = txns.expect_register(1).await;
1656        let mut d0_write = writer(&client, d0).await;
1657        let d1 = txns.expect_register(2).await;
1658        let mut d1_write = writer(&client, d1).await;
1659
1660        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[2]);
1661        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1662
1663        // Normal `apply` (used by expect_commit_at) does not advance the
1664        // physical upper of data shards that were not involved in the txn (lazy
1665        // upper). d1 is not involved in this txn so stays where it is.
1666        txns.expect_commit_at(3, d0, &["0-2"], &log).await;
1667        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1668        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[3]);
1669
1670        // d0 is not involved in this txn so stays where it is.
1671        txns.expect_commit_at(4, d1, &["1-3"], &log).await;
1672        assert_eq!(d0_write.fetch_recent_upper().await.elements(), &[4]);
1673        assert_eq!(d1_write.fetch_recent_upper().await.elements(), &[5]);
1674
1675        log.assert_snapshot(d0, 4).await;
1676        log.assert_snapshot(d1, 4).await;
1677    }
1678
1679    #[mz_ore::test(tokio::test)]
1680    #[cfg_attr(miri, ignore)]
1681    #[allow(clippy::unnecessary_get_then_check)] // The lint's suggestion is less readable here.
1682    async fn schemas() {
1683        let client = PersistClient::new_for_tests().await;
1684        let mut txns0 = TxnsHandle::expect_open(client.clone()).await;
1685        let mut txns1 = TxnsHandle::expect_open_id(client.clone(), txns0.txns_id()).await;
1686        let log = txns0.new_log();
1687        let d0 = txns0.expect_register(1).await;
1688
1689        // The register call happened on txns0, which means it has a real schema
1690        // and can commit batches.
1691        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1692        let mut txn = txns0.begin_test();
1693        txn.write(&d0, "foo".into(), (), 1).await;
1694        let apply = txn.commit_at(&mut txns0, 2).await.unwrap();
1695        log.record_txn(2, &txn);
1696
1697        // We can use a handle without a register call to apply a committed txn.
1698        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
1699        let _tidy = apply.apply(&mut txns1).await;
1700
1701        // However, it cannot commit batches.
1702        assert!(txns1.datas.data_write_for_commit.get(&d0).is_none());
1703        let res = mz_ore::task::spawn(|| "test", async move {
1704            let mut txn = txns1.begin();
1705            txn.write(&d0, "bar".into(), (), 1).await;
1706            // This panics.
1707            let _ = txn.commit_at(&mut txns1, 3).await;
1708        })
1709        .into_tokio_handle();
1710        assert!(res.await.is_err());
1711
1712        // Forgetting the data shard removes its write handle, so we don't
1713        // leave the schema sitting around.
1714        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1715        txns0.forget(3, [d0]).await.unwrap();
1716        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1717
1718        // Forget is idempotent.
1719        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1720        txns0.forget(4, [d0]).await.unwrap();
1721        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1722
1723        // We can register it again and commit again.
1724        assert_none!(txns0.datas.data_write_for_commit.get(&d0));
1725        txns0
1726            .register(5, [writer(&client, d0).await])
1727            .await
1728            .unwrap();
1729        assert!(txns0.datas.data_write_for_commit.get(&d0).is_some());
1730        txns0.expect_commit_at(6, d0, &["baz"], &log).await;
1731
1732        log.assert_snapshot(d0, 6).await;
1733    }
1734}