1use std::borrow::Borrow;
13use std::fmt::Debug;
14use std::sync::Arc;
15
16use differential_dataflow::difference::Semigroup;
17use differential_dataflow::lattice::Lattice;
18use differential_dataflow::trace::Description;
19use futures::StreamExt;
20use futures::stream::FuturesUnordered;
21use mz_dyncfg::Config;
22use mz_ore::instrument;
23use mz_ore::task::RuntimeExt;
24use mz_persist::location::Blob;
25use mz_persist_types::schema::SchemaId;
26use mz_persist_types::{Codec, Codec64};
27use mz_proto::{IntoRustIfSome, ProtoType};
28use proptest_derive::Arbitrary;
29use semver::Version;
30use serde::{Deserialize, Serialize};
31use timely::PartialOrder;
32use timely::order::TotalOrder;
33use timely::progress::{Antichain, Timestamp};
34use tokio::runtime::Handle;
35use tracing::{Instrument, debug_span, info, warn};
36use uuid::Uuid;
37
38use crate::batch::{
39    Added, BATCH_DELETE_ENABLED, Batch, BatchBuilder, BatchBuilderConfig, BatchBuilderInternal,
40    BatchParts, ProtoBatch, validate_truncate_batch,
41};
42use crate::error::{InvalidUsage, UpperMismatch};
43use crate::fetch::{
44    EncodedPart, FetchBatchFilter, FetchedPart, PartDecodeFormat, VALIDATE_PART_BOUNDS_ON_READ,
45};
46use crate::internal::compact::{CompactConfig, Compactor};
47use crate::internal::encoding::{Schemas, check_data_version};
48use crate::internal::machine::{CompareAndAppendRes, ExpireFn, Machine};
49use crate::internal::metrics::{BatchWriteMetrics, Metrics, ShardMetrics};
50use crate::internal::state::{BatchPart, HandleDebugState, HollowBatch, RunOrder, RunPart};
51use crate::read::ReadHandle;
52use crate::schema::PartMigration;
53use crate::{GarbageCollector, IsolatedRuntime, PersistConfig, ShardId, parse_id};
54
/// Dynamic config flag (default `true`) consulted when a compare-and-append
/// is rejected with inline backpressure.
///
/// When set, the inline parts of all input batches are re-encoded into a
/// single new batch that is flushed to blob storage before retrying. When
/// unset, each input batch is flushed to blob storage individually instead.
pub(crate) const COMBINE_INLINE_WRITES: Config<bool> = Config::new(
    "persist_write_combine_inline_writes",
    true,
    "If set, re-encode inline writes if they don't fit into the batch metadata limits.",
);
60
61pub(crate) const VALIDATE_PART_BOUNDS_ON_WRITE: Config<bool> = Config::new(
62    "persist_validate_part_bounds_on_write",
63    true,
64    "Validate the part lower <= the batch lower and the part upper <= batch upper,\
65    for the batch being appended.",
66);
67
/// An identifier for a writer of a persist shard, stored as 16 raw bytes.
///
/// Serde (de)serialization goes through `String`: the `into` direction uses
/// the `Display` form (`w` followed by the UUID rendering of the bytes) and
/// the `try_from` direction parses that same form.
#[derive(Arbitrary, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct WriterId(pub(crate) [u8; 16]);
72
73impl std::fmt::Display for WriterId {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        write!(f, "w{}", Uuid::from_bytes(self.0))
76    }
77}
78
79impl std::fmt::Debug for WriterId {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "WriterId({})", Uuid::from_bytes(self.0))
82    }
83}
84
85impl std::str::FromStr for WriterId {
86    type Err = String;
87
88    fn from_str(s: &str) -> Result<Self, Self::Err> {
89        parse_id("w", "WriterId", s).map(WriterId)
90    }
91}
92
93impl From<WriterId> for String {
94    fn from(writer_id: WriterId) -> Self {
95        writer_id.to_string()
96    }
97}
98
99impl TryFrom<String> for WriterId {
100    type Error = String;
101
102    fn try_from(s: String) -> Result<Self, Self::Error> {
103        s.parse()
104    }
105}
106
107impl WriterId {
108    pub(crate) fn new() -> Self {
109        WriterId(*Uuid::new_v4().as_bytes())
110    }
111}
112
/// A handle for appending batches of updates to a single persist shard.
///
/// The handle caches the most recent shard upper it has observed (see
/// `upper`). Dropping the handle without calling `expire` falls back to a
/// best-effort background expiry (see the `Drop` impl).
#[derive(Debug)]
pub struct WriteHandle<K: Codec, V: Codec, T, D> {
    // Configuration, also used to read dynamic config flags.
    pub(crate) cfg: PersistConfig,
    pub(crate) metrics: Arc<Metrics>,
    // State machine for the shard this handle writes to.
    pub(crate) machine: Machine<K, V, T, D>,
    pub(crate) gc: GarbageCollector<K, V, T, D>,
    // `Some` only when `cfg.compaction_enabled` was set at construction time.
    pub(crate) compact: Option<Compactor<K, V, T, D>>,
    pub(crate) blob: Arc<dyn Blob>,
    // Cloned from `machine.isolated_runtime` at construction time.
    pub(crate) isolated_runtime: Arc<IsolatedRuntime>,
    pub(crate) writer_id: WriterId,
    // Hostname and caller-provided purpose, passed along on appends.
    pub(crate) debug_state: HandleDebugState,
    pub(crate) write_schemas: Schemas<K, V>,

    // Cached copy of the shard upper as last observed by this handle.
    pub(crate) upper: Antichain<T>,
    // Taken (set to `None`) by `expire` or by `Drop`, whichever runs first.
    expire_fn: Option<ExpireFn>,
}
144
145impl<K, V, T, D> WriteHandle<K, V, T, D>
146where
147    K: Debug + Codec,
148    V: Debug + Codec,
149    T: Timestamp + TotalOrder + Lattice + Codec64 + Sync,
150    D: Semigroup + Ord + Codec64 + Send + Sync,
151{
152    pub(crate) fn new(
153        cfg: PersistConfig,
154        metrics: Arc<Metrics>,
155        machine: Machine<K, V, T, D>,
156        gc: GarbageCollector<K, V, T, D>,
157        blob: Arc<dyn Blob>,
158        writer_id: WriterId,
159        purpose: &str,
160        write_schemas: Schemas<K, V>,
161    ) -> Self {
162        let isolated_runtime = Arc::clone(&machine.isolated_runtime);
163        let compact = cfg.compaction_enabled.then(|| {
164            Compactor::new(
165                cfg.clone(),
166                Arc::clone(&metrics),
167                write_schemas.clone(),
168                gc.clone(),
169            )
170        });
171        let debug_state = HandleDebugState {
172            hostname: cfg.hostname.to_owned(),
173            purpose: purpose.to_owned(),
174        };
175        let upper = machine.applier.clone_upper();
176        let expire_fn = Self::expire_fn(machine.clone(), gc.clone(), writer_id.clone());
177        WriteHandle {
178            cfg,
179            metrics,
180            machine,
181            gc,
182            compact,
183            blob,
184            isolated_runtime,
185            writer_id,
186            debug_state,
187            write_schemas,
188            upper,
189            expire_fn: Some(expire_fn),
190        }
191    }
192
193    pub fn from_read(read: &ReadHandle<K, V, T, D>, purpose: &str) -> Self {
196        Self::new(
197            read.cfg.clone(),
198            Arc::clone(&read.metrics),
199            read.machine.clone(),
200            read.gc.clone(),
201            Arc::clone(&read.blob),
202            WriterId::new(),
203            purpose,
204            read.read_schemas.clone(),
205        )
206    }
207
208    pub fn validate_part_bounds_on_write(&self) -> bool {
211        VALIDATE_PART_BOUNDS_ON_WRITE.get(&self.cfg) && VALIDATE_PART_BOUNDS_ON_READ.get(&self.cfg)
212    }
213
214    pub fn shard_id(&self) -> ShardId {
216        self.machine.shard_id()
217    }
218
219    pub fn schema_id(&self) -> Option<SchemaId> {
221        self.write_schemas.id
222    }
223
224    pub fn upper(&self) -> &Antichain<T> {
231        &self.upper
232    }
233
234    pub fn shared_upper(&self) -> Antichain<T> {
240        self.machine.applier.clone_upper()
241    }
242
243    #[instrument(level = "debug", fields(shard = %self.machine.shard_id()))]
249    pub async fn fetch_recent_upper(&mut self) -> &Antichain<T> {
250        self.machine
253            .applier
254            .fetch_upper(|current_upper| self.upper.clone_from(current_upper))
255            .await;
256        &self.upper
257    }
258
259    #[instrument(level = "trace", fields(shard = %self.machine.shard_id()))]
289    pub async fn append<SB, KB, VB, TB, DB, I>(
290        &mut self,
291        updates: I,
292        lower: Antichain<T>,
293        upper: Antichain<T>,
294    ) -> Result<Result<(), UpperMismatch<T>>, InvalidUsage<T>>
295    where
296        SB: Borrow<((KB, VB), TB, DB)>,
297        KB: Borrow<K>,
298        VB: Borrow<V>,
299        TB: Borrow<T>,
300        DB: Borrow<D>,
301        I: IntoIterator<Item = SB>,
302        D: Send + Sync,
303    {
304        let batch = self.batch(updates, lower.clone(), upper.clone()).await?;
305        self.append_batch(batch, lower, upper).await
306    }
307
308    #[instrument(level = "trace", fields(shard = %self.machine.shard_id()))]
337    pub async fn compare_and_append<SB, KB, VB, TB, DB, I>(
338        &mut self,
339        updates: I,
340        expected_upper: Antichain<T>,
341        new_upper: Antichain<T>,
342    ) -> Result<Result<(), UpperMismatch<T>>, InvalidUsage<T>>
343    where
344        SB: Borrow<((KB, VB), TB, DB)>,
345        KB: Borrow<K>,
346        VB: Borrow<V>,
347        TB: Borrow<T>,
348        DB: Borrow<D>,
349        I: IntoIterator<Item = SB>,
350        D: Send + Sync,
351    {
352        let mut batch = self
353            .batch(updates, expected_upper.clone(), new_upper.clone())
354            .await?;
355        match self
356            .compare_and_append_batch(&mut [&mut batch], expected_upper, new_upper, true)
357            .await
358        {
359            ok @ Ok(Ok(())) => ok,
360            err => {
361                batch.delete().await;
366                err
367            }
368        }
369    }
370
371    #[instrument(level = "trace", fields(shard = %self.machine.shard_id()))]
397    pub async fn append_batch(
398        &mut self,
399        mut batch: Batch<K, V, T, D>,
400        mut lower: Antichain<T>,
401        upper: Antichain<T>,
402    ) -> Result<Result<(), UpperMismatch<T>>, InvalidUsage<T>>
403    where
404        D: Send + Sync,
405    {
406        loop {
407            let res = self
408                .compare_and_append_batch(&mut [&mut batch], lower.clone(), upper.clone(), true)
409                .await?;
410            match res {
411                Ok(()) => {
412                    self.upper = upper;
413                    return Ok(Ok(()));
414                }
415                Err(mismatch) => {
416                    if PartialOrder::less_than(&mismatch.current, &lower) {
418                        self.upper.clone_from(&mismatch.current);
419
420                        batch.delete().await;
421
422                        return Ok(Err(mismatch));
423                    } else if PartialOrder::less_than(&mismatch.current, &upper) {
424                        lower = mismatch.current;
431                    } else {
432                        self.upper = mismatch.current;
434
435                        batch.delete().await;
439
440                        return Ok(Ok(()));
441                    }
442                }
443            }
444        }
445    }
446
447    #[instrument(level = "debug", fields(shard = %self.machine.shard_id()))]
481    pub async fn compare_and_append_batch(
482        &mut self,
483        batches: &mut [&mut Batch<K, V, T, D>],
484        expected_upper: Antichain<T>,
485        new_upper: Antichain<T>,
486        validate_part_bounds_on_write: bool,
487    ) -> Result<Result<(), UpperMismatch<T>>, InvalidUsage<T>>
488    where
489        D: Send + Sync,
490    {
491        for batch in batches.iter() {
492            if self.machine.shard_id() != batch.shard_id() {
493                return Err(InvalidUsage::BatchNotFromThisShard {
494                    batch_shard: batch.shard_id(),
495                    handle_shard: self.machine.shard_id(),
496                });
497            }
498            check_data_version(&self.cfg.build_version, &batch.version);
499            if self.cfg.build_version > batch.version {
500                info!(
501                    shard_id =? self.machine.shard_id(),
502                    batch_version =? batch.version,
503                    writer_version =? self.cfg.build_version,
504                    "Appending batch from the past. This is fine but should be rare. \
505                    TODO: Error on very old versions once the leaked blob detector exists."
506                )
507            }
508        }
509
510        let lower = expected_upper.clone();
511        let upper = new_upper;
512        let since = Antichain::from_elem(T::minimum());
513        let desc = Description::new(lower, upper, since);
514
515        let mut received_inline_backpressure = false;
516        let mut inline_batch_builder: Option<(_, BatchBuilder<K, V, T, D>)> = None;
523        let maintenance = loop {
524            let any_batch_rewrite = batches
525                .iter()
526                .any(|x| x.batch.parts.iter().any(|x| x.ts_rewrite().is_some()));
527            let (mut parts, mut num_updates, mut run_splits, mut run_metas) =
528                (vec![], 0, vec![], vec![]);
529            let mut key_storage = None;
530            let mut val_storage = None;
531            for batch in batches.iter() {
532                let () = validate_truncate_batch(
533                    &batch.batch,
534                    &desc,
535                    any_batch_rewrite,
536                    validate_part_bounds_on_write,
537                )?;
538                for (run_meta, run) in batch.batch.runs() {
539                    let start_index = parts.len();
540                    for part in run {
541                        if let (
542                            RunPart::Single(
543                                batch_part @ BatchPart::Inline {
544                                    updates,
545                                    ts_rewrite,
546                                    schema_id: _,
547                                    deprecated_schema_id: _,
548                                },
549                            ),
550                            Some((schema_cache, builder)),
551                        ) = (part, &mut inline_batch_builder)
552                        {
553                            let schema_migration = PartMigration::new(
554                                batch_part,
555                                self.write_schemas.clone(),
556                                schema_cache,
557                            )
558                            .await
559                            .expect("schemas for inline user part");
560
561                            let encoded_part = EncodedPart::from_inline(
562                                &crate::fetch::FetchConfig::from_persist_config(&self.cfg),
563                                &*self.metrics,
564                                self.metrics.read.compaction.clone(),
565                                desc.clone(),
566                                updates,
567                                ts_rewrite.as_ref(),
568                            );
569                            let mut fetched_part = FetchedPart::new(
570                                Arc::clone(&self.metrics),
571                                encoded_part,
572                                schema_migration,
573                                FetchBatchFilter::Compaction {
574                                    since: desc.since().clone(),
575                                },
576                                false,
577                                PartDecodeFormat::Arrow,
578                                None,
579                            );
580
581                            while let Some(((k, v), t, d)) =
582                                fetched_part.next_with_storage(&mut key_storage, &mut val_storage)
583                            {
584                                builder
585                                    .add(
586                                        &k.expect("decoded just-encoded key data"),
587                                        &v.expect("decoded just-encoded val data"),
588                                        &t,
589                                        &d,
590                                    )
591                                    .await
592                                    .expect("re-encoding just-decoded data");
593                            }
594                        } else {
595                            parts.push(part.clone())
596                        }
597                    }
598
599                    let end_index = parts.len();
600
601                    if start_index == end_index {
602                        continue;
603                    }
604
605                    if start_index != 0 {
607                        run_splits.push(start_index);
608                    }
609                    run_metas.push(run_meta.clone());
610                }
611                num_updates += batch.batch.len;
612            }
613
614            let mut flushed_inline_batch = if let Some((_, builder)) = inline_batch_builder.take() {
615                let mut finished = builder
616                    .finish(desc.upper().clone())
617                    .await
618                    .expect("invalid usage");
619                let cfg = BatchBuilderConfig::new(&self.cfg, self.shard_id());
620                finished
621                    .flush_to_blob(
622                        &cfg,
623                        &self.metrics.inline.backpressure,
624                        &self.isolated_runtime,
625                        &self.write_schemas,
626                    )
627                    .await;
628                Some(finished)
629            } else {
630                None
631            };
632
633            if let Some(batch) = &flushed_inline_batch {
634                for (run_meta, run) in batch.batch.runs() {
635                    assert!(run.len() > 0);
636                    let start_index = parts.len();
637                    if start_index != 0 {
638                        run_splits.push(start_index);
639                    }
640                    run_metas.push(run_meta.clone());
641                    parts.extend(run.iter().cloned())
642                }
643            }
644
645            let combined_batch =
646                HollowBatch::new(desc.clone(), parts, num_updates, run_metas, run_splits);
647            let heartbeat_timestamp = (self.cfg.now)();
648            let res = self
649                .machine
650                .compare_and_append(
651                    &combined_batch,
652                    &self.writer_id,
653                    &self.debug_state,
654                    heartbeat_timestamp,
655                )
656                .await;
657
658            match res {
659                CompareAndAppendRes::Success(_seqno, maintenance) => {
660                    self.upper.clone_from(desc.upper());
661                    for batch in batches.iter_mut() {
662                        batch.mark_consumed();
663                    }
664                    if let Some(batch) = &mut flushed_inline_batch {
665                        batch.mark_consumed();
666                    }
667                    break maintenance;
668                }
669                CompareAndAppendRes::InvalidUsage(invalid_usage) => {
670                    if let Some(batch) = flushed_inline_batch.take() {
671                        batch.delete().await;
672                    }
673                    return Err(invalid_usage);
674                }
675                CompareAndAppendRes::UpperMismatch(_seqno, current_upper) => {
676                    if let Some(batch) = flushed_inline_batch.take() {
677                        batch.delete().await;
678                    }
679                    self.upper.clone_from(¤t_upper);
682                    return Ok(Err(UpperMismatch {
683                        current: current_upper,
684                        expected: expected_upper,
685                    }));
686                }
687                CompareAndAppendRes::InlineBackpressure => {
688                    assert_eq!(received_inline_backpressure, false);
691                    received_inline_backpressure = true;
692                    if COMBINE_INLINE_WRITES.get(&self.cfg) {
693                        inline_batch_builder = Some((
694                            self.machine.applier.schema_cache(),
695                            self.builder(desc.lower().clone()),
696                        ));
697                        continue;
698                    }
699
700                    let cfg = BatchBuilderConfig::new(&self.cfg, self.shard_id());
701                    let flush_batches = batches
704                        .iter_mut()
705                        .map(|batch| async {
706                            batch
707                                .flush_to_blob(
708                                    &cfg,
709                                    &self.metrics.inline.backpressure,
710                                    &self.isolated_runtime,
711                                    &self.write_schemas,
712                                )
713                                .await
714                        })
715                        .collect::<FuturesUnordered<_>>();
716                    let () = flush_batches.collect::<()>().await;
717
718                    for batch in batches.iter() {
719                        assert_eq!(batch.batch.inline_bytes(), 0);
720                    }
721
722                    continue;
723                }
724            }
725        };
726
727        maintenance.start_performing(&self.machine, &self.gc, self.compact.as_ref());
728
729        Ok(Ok(()))
730    }
731
732    pub fn batch_from_transmittable_batch(&self, batch: ProtoBatch) -> Batch<K, V, T, D> {
735        let shard_id: ShardId = batch
736            .shard_id
737            .into_rust()
738            .expect("valid transmittable batch");
739        assert_eq!(shard_id, self.machine.shard_id());
740
741        let ret = Batch {
742            batch_delete_enabled: BATCH_DELETE_ENABLED.get(&self.cfg),
743            metrics: Arc::clone(&self.metrics),
744            shard_metrics: Arc::clone(&self.machine.applier.shard_metrics),
745            version: Version::parse(&batch.version).expect("valid transmittable batch"),
746            batch: batch
747                .batch
748                .into_rust_if_some("ProtoBatch::batch")
749                .expect("valid transmittable batch"),
750            blob: Arc::clone(&self.blob),
751            _phantom: std::marker::PhantomData,
752        };
753        assert_eq!(ret.shard_id(), self.machine.shard_id());
754        ret
755    }
756
757    pub fn builder(&self, lower: Antichain<T>) -> BatchBuilder<K, V, T, D> {
770        Self::builder_inner(
771            &self.cfg,
772            CompactConfig::new(&self.cfg, self.shard_id()),
773            Arc::clone(&self.metrics),
774            Arc::clone(&self.machine.applier.shard_metrics),
775            &self.metrics.user,
776            Arc::clone(&self.isolated_runtime),
777            Arc::clone(&self.blob),
778            self.shard_id(),
779            self.write_schemas.clone(),
780            lower,
781        )
782    }
783
784    pub(crate) fn builder_inner(
787        persist_cfg: &PersistConfig,
788        compact_cfg: CompactConfig,
789        metrics: Arc<Metrics>,
790        shard_metrics: Arc<ShardMetrics>,
791        user_batch_metrics: &BatchWriteMetrics,
792        isolated_runtime: Arc<IsolatedRuntime>,
793        blob: Arc<dyn Blob>,
794        shard_id: ShardId,
795        schemas: Schemas<K, V>,
796        lower: Antichain<T>,
797    ) -> BatchBuilder<K, V, T, D> {
798        let parts = if let Some(max_runs) = compact_cfg.batch.max_runs {
799            BatchParts::new_compacting::<K, V, D>(
800                compact_cfg,
801                Description::new(
802                    lower.clone(),
803                    Antichain::new(),
804                    Antichain::from_elem(T::minimum()),
805                ),
806                max_runs,
807                Arc::clone(&metrics),
808                shard_metrics,
809                shard_id,
810                Arc::clone(&blob),
811                isolated_runtime,
812                user_batch_metrics,
813                schemas.clone(),
814            )
815        } else {
816            BatchParts::new_ordered::<D>(
817                compact_cfg.batch,
818                RunOrder::Unordered,
819                Arc::clone(&metrics),
820                shard_metrics,
821                shard_id,
822                Arc::clone(&blob),
823                isolated_runtime,
824                user_batch_metrics,
825            )
826        };
827        let builder = BatchBuilderInternal::new(
828            BatchBuilderConfig::new(persist_cfg, shard_id),
829            parts,
830            metrics,
831            schemas,
832            blob,
833            shard_id,
834            persist_cfg.build_version.clone(),
835        );
836        BatchBuilder::new(
837            builder,
838            Description::new(lower, Antichain::new(), Antichain::from_elem(T::minimum())),
839        )
840    }
841
842    #[instrument(level = "trace", fields(shard = %self.machine.shard_id()))]
845    pub async fn batch<SB, KB, VB, TB, DB, I>(
846        &mut self,
847        updates: I,
848        lower: Antichain<T>,
849        upper: Antichain<T>,
850    ) -> Result<Batch<K, V, T, D>, InvalidUsage<T>>
851    where
852        SB: Borrow<((KB, VB), TB, DB)>,
853        KB: Borrow<K>,
854        VB: Borrow<V>,
855        TB: Borrow<T>,
856        DB: Borrow<D>,
857        I: IntoIterator<Item = SB>,
858    {
859        let iter = updates.into_iter();
860
861        let mut builder = self.builder(lower.clone());
862
863        for update in iter {
864            let ((k, v), t, d) = update.borrow();
865            let (k, v, t, d) = (k.borrow(), v.borrow(), t.borrow(), d.borrow());
866            match builder.add(k, v, t, d).await {
867                Ok(Added::Record | Added::RecordAndParts) => (),
868                Err(invalid_usage) => return Err(invalid_usage),
869            }
870        }
871
872        builder.finish(upper.clone()).await
873    }
874
875    pub async fn wait_for_upper_past(&mut self, frontier: &Antichain<T>) {
877        let mut watch = self.machine.applier.watch();
878        let batch = self
879            .machine
880            .next_listen_batch(frontier, &mut watch, None, None)
881            .await;
882        if PartialOrder::less_than(&self.upper, batch.desc.upper()) {
883            self.upper.clone_from(batch.desc.upper());
884        }
885        assert!(PartialOrder::less_than(frontier, &self.upper));
886    }
887
888    #[instrument(level = "debug", fields(shard = %self.machine.shard_id()))]
897    pub async fn expire(mut self) {
898        let Some(expire_fn) = self.expire_fn.take() else {
899            return;
900        };
901        expire_fn.0().await;
902    }
903
    /// Builds the callback that expires the writer identified by `writer_id`.
    ///
    /// The callback owns `machine`, `gc`, and `writer_id`. When invoked it
    /// returns a boxed future that expires the writer on the machine and then
    /// starts the maintenance work returned by that call.
    fn expire_fn(
        machine: Machine<K, V, T, D>,
        gc: GarbageCollector<K, V, T, D>,
        writer_id: WriterId,
    ) -> ExpireFn {
        // NB: the closure is written inline so its return type is inferred
        // from the `ExpireFn` field type.
        ExpireFn(Box::new(move || {
            Box::pin(async move {
                let (_, maintenance) = machine.expire_writer(&writer_id).await;
                maintenance.start_performing(&machine, &gc);
            })
        }))
    }
916
917    #[cfg(test)]
919    #[track_caller]
920    pub async fn expect_append<L, U>(&mut self, updates: &[((K, V), T, D)], lower: L, new_upper: U)
921    where
922        L: Into<Antichain<T>>,
923        U: Into<Antichain<T>>,
924        D: Send + Sync,
925    {
926        self.append(updates.iter(), lower.into(), new_upper.into())
927            .await
928            .expect("invalid usage")
929            .expect("unexpected upper");
930    }
931
932    #[cfg(test)]
935    #[track_caller]
936    pub async fn expect_compare_and_append(
937        &mut self,
938        updates: &[((K, V), T, D)],
939        expected_upper: T,
940        new_upper: T,
941    ) where
942        D: Send + Sync,
943    {
944        self.compare_and_append(
945            updates.iter().map(|((k, v), t, d)| ((k, v), t, d)),
946            Antichain::from_elem(expected_upper),
947            Antichain::from_elem(new_upper),
948        )
949        .await
950        .expect("invalid usage")
951        .expect("unexpected upper")
952    }
953
954    #[cfg(test)]
957    #[track_caller]
958    pub async fn expect_compare_and_append_batch(
959        &mut self,
960        batches: &mut [&mut Batch<K, V, T, D>],
961        expected_upper: T,
962        new_upper: T,
963    ) {
964        self.compare_and_append_batch(
965            batches,
966            Antichain::from_elem(expected_upper),
967            Antichain::from_elem(new_upper),
968            true,
969        )
970        .await
971        .expect("invalid usage")
972        .expect("unexpected upper")
973    }
974
975    #[cfg(test)]
977    #[track_caller]
978    pub async fn expect_batch(
979        &mut self,
980        updates: &[((K, V), T, D)],
981        lower: T,
982        upper: T,
983    ) -> Batch<K, V, T, D> {
984        self.batch(
985            updates.iter(),
986            Antichain::from_elem(lower),
987            Antichain::from_elem(upper),
988        )
989        .await
990        .expect("invalid usage")
991    }
992}
993
994impl<K: Codec, V: Codec, T, D> Drop for WriteHandle<K, V, T, D> {
995    fn drop(&mut self) {
996        let Some(expire_fn) = self.expire_fn.take() else {
997            return;
998        };
999        let handle = match Handle::try_current() {
1000            Ok(x) => x,
1001            Err(_) => {
1002                warn!(
1003                    "WriteHandle {} dropped without being explicitly expired, falling back to lease timeout",
1004                    self.writer_id
1005                );
1006                return;
1007            }
1008        };
1009        let expire_span = debug_span!("drop::expire");
1015        handle.spawn_named(
1016            || format!("WriteHandle::expire ({})", self.writer_id),
1017            expire_fn.0().instrument(expire_span),
1018        );
1019    }
1020}
1021
1022#[cfg(test)]
1023mod tests {
1024    use std::str::FromStr;
1025    use std::sync::mpsc;
1026
1027    use differential_dataflow::consolidation::consolidate_updates;
1028    use futures_util::FutureExt;
1029    use mz_dyncfg::ConfigUpdates;
1030    use mz_ore::collections::CollectionExt;
1031    use mz_ore::task;
1032    use serde_json::json;
1033
1034    use crate::cache::PersistClientCache;
1035    use crate::tests::{all_ok, new_test_client};
1036    use crate::{PersistLocation, ShardId};
1037
1038    use super::*;
1039
1040    #[mz_persist_proc::test(tokio::test)]
1041    #[cfg_attr(miri, ignore)] async fn empty_batches(dyncfgs: ConfigUpdates) {
1043        let data = [
1044            (("1".to_owned(), "one".to_owned()), 1, 1),
1045            (("2".to_owned(), "two".to_owned()), 2, 1),
1046            (("3".to_owned(), "three".to_owned()), 3, 1),
1047        ];
1048
1049        let (mut write, _) = new_test_client(&dyncfgs)
1050            .await
1051            .expect_open::<String, String, u64, i64>(ShardId::new())
1052            .await;
1053        let blob = Arc::clone(&write.blob);
1054
1055        let mut upper = 3;
1057        write.expect_append(&data[..2], vec![0], vec![upper]).await;
1058
1059        let mut count_before = 0;
1061        blob.list_keys_and_metadata("", &mut |_| {
1062            count_before += 1;
1063        })
1064        .await
1065        .expect("list_keys failed");
1066        for _ in 0..5 {
1067            let new_upper = upper + 1;
1068            write.expect_compare_and_append(&[], upper, new_upper).await;
1069            upper = new_upper;
1070        }
1071        let mut count_after = 0;
1072        blob.list_keys_and_metadata("", &mut |_| {
1073            count_after += 1;
1074        })
1075        .await
1076        .expect("list_keys failed");
1077        assert_eq!(count_after, count_before);
1078    }
1079
1080    #[mz_persist_proc::test(tokio::test)]
1081    #[cfg_attr(miri, ignore)] async fn compare_and_append_batch_multi(dyncfgs: ConfigUpdates) {
1083        let data0 = vec![
1084            (("1".to_owned(), "one".to_owned()), 1, 1),
1085            (("2".to_owned(), "two".to_owned()), 2, 1),
1086            (("4".to_owned(), "four".to_owned()), 4, 1),
1087        ];
1088        let data1 = vec![
1089            (("1".to_owned(), "one".to_owned()), 1, 1),
1090            (("2".to_owned(), "two".to_owned()), 2, 1),
1091            (("3".to_owned(), "three".to_owned()), 3, 1),
1092        ];
1093
1094        let (mut write, mut read) = new_test_client(&dyncfgs)
1095            .await
1096            .expect_open::<String, String, u64, i64>(ShardId::new())
1097            .await;
1098
1099        let mut batch0 = write.expect_batch(&data0, 0, 5).await;
1100        let mut batch1 = write.expect_batch(&data1, 0, 4).await;
1101
1102        write
1103            .expect_compare_and_append_batch(&mut [&mut batch0, &mut batch1], 0, 4)
1104            .await;
1105
1106        let batch = write
1107            .machine
1108            .snapshot(&Antichain::from_elem(3))
1109            .await
1110            .expect("just wrote this")
1111            .into_element();
1112
1113        assert!(batch.runs().count() >= 2);
1114
1115        let expected = vec![
1116            (("1".to_owned(), "one".to_owned()), 1, 2),
1117            (("2".to_owned(), "two".to_owned()), 2, 2),
1118            (("3".to_owned(), "three".to_owned()), 3, 1),
1119        ];
1120        let mut actual = read.expect_snapshot_and_fetch(3).await;
1121        consolidate_updates(&mut actual);
1122        assert_eq!(actual, all_ok(&expected, 3));
1123    }
1124
1125    #[mz_ore::test]
1126    fn writer_id_human_readable_serde() {
1127        #[derive(Debug, Serialize, Deserialize)]
1128        struct Container {
1129            writer_id: WriterId,
1130        }
1131
1132        let id = WriterId::from_str("w00000000-1234-5678-0000-000000000000").expect("valid id");
1134        assert_eq!(
1135            id,
1136            serde_json::from_value(serde_json::to_value(id.clone()).expect("serializable"))
1137                .expect("deserializable")
1138        );
1139
1140        assert_eq!(
1142            id,
1143            serde_json::from_str("\"w00000000-1234-5678-0000-000000000000\"")
1144                .expect("deserializable")
1145        );
1146
1147        let json = json!({ "writer_id": id });
1149        assert_eq!(
1150            "{\"writer_id\":\"w00000000-1234-5678-0000-000000000000\"}",
1151            &json.to_string()
1152        );
1153        let container: Container = serde_json::from_value(json).expect("deserializable");
1154        assert_eq!(container.writer_id, id);
1155    }
1156
1157    #[mz_persist_proc::test(tokio::test)]
1158    #[cfg_attr(miri, ignore)] async fn hollow_batch_roundtrip(dyncfgs: ConfigUpdates) {
1160        let data = vec![
1161            (("1".to_owned(), "one".to_owned()), 1, 1),
1162            (("2".to_owned(), "two".to_owned()), 2, 1),
1163            (("3".to_owned(), "three".to_owned()), 3, 1),
1164        ];
1165
1166        let (mut write, mut read) = new_test_client(&dyncfgs)
1167            .await
1168            .expect_open::<String, String, u64, i64>(ShardId::new())
1169            .await;
1170
1171        let batch = write.expect_batch(&data, 0, 4).await;
1176        let hollow_batch = batch.into_transmittable_batch();
1177        let mut rehydrated_batch = write.batch_from_transmittable_batch(hollow_batch);
1178
1179        write
1180            .expect_compare_and_append_batch(&mut [&mut rehydrated_batch], 0, 4)
1181            .await;
1182
1183        let expected = vec![
1184            (("1".to_owned(), "one".to_owned()), 1, 1),
1185            (("2".to_owned(), "two".to_owned()), 2, 1),
1186            (("3".to_owned(), "three".to_owned()), 3, 1),
1187        ];
1188        let mut actual = read.expect_snapshot_and_fetch(3).await;
1189        consolidate_updates(&mut actual);
1190        assert_eq!(actual, all_ok(&expected, 3));
1191    }
1192
1193    #[mz_persist_proc::test(tokio::test)]
1194    #[cfg_attr(miri, ignore)] async fn wait_for_upper_past(dyncfgs: ConfigUpdates) {
1196        let client = new_test_client(&dyncfgs).await;
1197        let (mut write, _) = client.expect_open::<(), (), u64, i64>(ShardId::new()).await;
1198        let five = Antichain::from_elem(5);
1199
1200        assert_eq!(write.wait_for_upper_past(&five).now_or_never(), None);
1202
1203        write
1205            .expect_compare_and_append(&[(((), ()), 1, 1)], 0, 5)
1206            .await;
1207        assert_eq!(write.wait_for_upper_past(&five).now_or_never(), None);
1208
1209        write
1211            .expect_compare_and_append(&[(((), ()), 5, 1)], 5, 7)
1212            .await;
1213        assert_eq!(write.wait_for_upper_past(&five).now_or_never(), Some(()));
1214        assert_eq!(write.upper(), &Antichain::from_elem(7));
1215
1216        assert_eq!(
1219            write
1220                .wait_for_upper_past(&Antichain::from_elem(2))
1221                .now_or_never(),
1222            Some(())
1223        );
1224        assert_eq!(write.upper(), &Antichain::from_elem(7));
1225    }
1226
1227    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
1228    #[cfg_attr(miri, ignore)] async fn fetch_recent_upper_linearized() {
1230        type Timestamp = u64;
1231        let max_upper = 1000;
1232
1233        let shard_id = ShardId::new();
1234        let mut clients = PersistClientCache::new_no_metrics();
1235        let upper_writer_client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1236        let (mut upper_writer, _) = upper_writer_client
1237            .expect_open::<(), (), Timestamp, i64>(shard_id)
1238            .await;
1239        clients.clear_state_cache();
1242        let upper_reader_client = clients.open(PersistLocation::new_in_mem()).await.unwrap();
1243        let (mut upper_reader, _) = upper_reader_client
1244            .expect_open::<(), (), Timestamp, i64>(shard_id)
1245            .await;
1246        let (tx, rx) = mpsc::channel();
1247
1248        let task = task::spawn(|| "upper-reader", async move {
1249            let mut upper = Timestamp::MIN;
1250
1251            while upper < max_upper {
1252                while let Ok(new_upper) = rx.try_recv() {
1253                    upper = new_upper;
1254                }
1255
1256                let recent_upper = upper_reader
1257                    .fetch_recent_upper()
1258                    .await
1259                    .as_option()
1260                    .cloned()
1261                    .expect("u64 is totally ordered and the shard is not finalized");
1262                assert!(
1263                    recent_upper >= upper,
1264                    "recent upper {recent_upper:?} is less than known upper {upper:?}"
1265                );
1266            }
1267        });
1268
1269        for upper in Timestamp::MIN..max_upper {
1270            let next_upper = upper + 1;
1271            upper_writer
1272                .expect_compare_and_append(&[], upper, next_upper)
1273                .await;
1274            tx.send(next_upper).expect("send failed");
1275        }
1276
1277        task.await.expect("await failed");
1278    }
1279}