mz_persist_client/internal/state_diff.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::sync::Arc;

use bytes::{Bytes, BytesMut};
use differential_dataflow::lattice::Lattice;
use differential_dataflow::trace::Description;
use mz_ore::assert_none;
use mz_ore::cast::CastFrom;
use mz_persist::location::{SeqNo, VersionedData};
use mz_persist_types::Codec64;
use mz_persist_types::schema::SchemaId;
use mz_proto::TryFromProtoError;
use timely::PartialOrder;
use timely::progress::{Antichain, Timestamp};
use tracing::debug;

use crate::critical::CriticalReaderId;
use crate::internal::paths::PartialRollupKey;
use crate::internal::state::{
    CriticalReaderState, EncodedSchemas, HollowBatch, HollowBlobRef, HollowRollup,
    LeasedReaderState, ProtoStateField, ProtoStateFieldDiffType, ProtoStateFieldDiffs, State,
    StateCollections, WriterState,
};
use crate::internal::trace::{FueledMergeRes, SpineId, ThinMerge, ThinSpineBatch, Trace};
use crate::read::LeasedReaderId;
use crate::write::WriterId;
use crate::{Metrics, PersistConfig, ShardId};

use StateFieldValDiff::*;

use super::state::{ActiveGc, ActiveRollup};

#[derive(Clone, Debug)]
#[cfg_attr(any(test, debug_assertions), derive(PartialEq))]
pub enum StateFieldValDiff<V> {
    Insert(V),
    Update(V, V),
    Delete(V),
}

#[derive(Clone)]
#[cfg_attr(any(test, debug_assertions), derive(PartialEq))]
pub struct StateFieldDiff<K, V> {
    pub key: K,
    pub val: StateFieldValDiff<V>,
}

impl<K: Debug, V: Debug> std::fmt::Debug for StateFieldDiff<K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateFieldDiff")
            // In the cases we've seen in the wild, it's been more useful to
            // have the val printed first.
            .field("val", &self.val)
            .field("key", &self.key)
            .finish()
    }
}

#[derive(Debug)]
#[cfg_attr(any(test, debug_assertions), derive(Clone, PartialEq))]
pub struct StateDiff<T> {
    pub(crate) applier_version: semver::Version,
    pub(crate) seqno_from: SeqNo,
    pub(crate) seqno_to: SeqNo,
    pub(crate) walltime_ms: u64,
    pub(crate) latest_rollup_key: PartialRollupKey,
    pub(crate) rollups: Vec<StateFieldDiff<SeqNo, HollowRollup>>,
    pub(crate) active_rollup: Vec<StateFieldDiff<(), ActiveRollup>>,
    pub(crate) active_gc: Vec<StateFieldDiff<(), ActiveGc>>,
    pub(crate) hostname: Vec<StateFieldDiff<(), String>>,
    pub(crate) last_gc_req: Vec<StateFieldDiff<(), SeqNo>>,
    pub(crate) leased_readers: Vec<StateFieldDiff<LeasedReaderId, LeasedReaderState<T>>>,
    pub(crate) critical_readers: Vec<StateFieldDiff<CriticalReaderId, CriticalReaderState<T>>>,
    pub(crate) writers: Vec<StateFieldDiff<WriterId, WriterState<T>>>,
    pub(crate) schemas: Vec<StateFieldDiff<SchemaId, EncodedSchemas>>,
    pub(crate) since: Vec<StateFieldDiff<(), Antichain<T>>>,
    pub(crate) legacy_batches: Vec<StateFieldDiff<HollowBatch<T>, ()>>,
    pub(crate) hollow_batches: Vec<StateFieldDiff<SpineId, Arc<HollowBatch<T>>>>,
    pub(crate) spine_batches: Vec<StateFieldDiff<SpineId, ThinSpineBatch<T>>>,
    pub(crate) merges: Vec<StateFieldDiff<SpineId, ThinMerge<T>>>,
}

impl<T: Timestamp + Codec64> StateDiff<T> {
    pub fn new(
        applier_version: semver::Version,
        seqno_from: SeqNo,
        seqno_to: SeqNo,
        walltime_ms: u64,
        latest_rollup_key: PartialRollupKey,
    ) -> Self {
        StateDiff {
            applier_version,
            seqno_from,
            seqno_to,
            walltime_ms,
            latest_rollup_key,
            rollups: Vec::default(),
            active_rollup: Vec::default(),
            active_gc: Vec::default(),
            hostname: Vec::default(),
            last_gc_req: Vec::default(),
            leased_readers: Vec::default(),
            critical_readers: Vec::default(),
            writers: Vec::default(),
            schemas: Vec::default(),
            since: Vec::default(),
            legacy_batches: Vec::default(),
            hollow_batches: Vec::default(),
            spine_batches: Vec::default(),
            merges: Vec::default(),
        }
    }

    pub fn referenced_batches(&self) -> impl Iterator<Item = StateFieldValDiff<&HollowBatch<T>>> {
        let legacy_batches = self
            .legacy_batches
            .iter()
            .filter_map(|diff| match diff.val {
                Insert(()) => Some(Insert(&diff.key)),
                Update((), ()) => None, // Ignoring a noop diff.
                Delete(()) => Some(Delete(&diff.key)),
            });
        let hollow_batches = self.hollow_batches.iter().map(|diff| match &diff.val {
            Insert(batch) => Insert(&**batch),
            Update(before, after) => Update(&**before, &**after),
            Delete(batch) => Delete(&**batch),
        });
        legacy_batches.chain(hollow_batches)
    }
}

impl<T: Timestamp + Lattice + Codec64> StateDiff<T> {
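    /// Computes a field-by-field diff that, when applied to `from`, produces `to`.
    ///
    /// A minimal sketch of the intended roundtrip, mirroring what
    /// `validate_roundtrip` below checks (`old_state`, `new_state`, and
    /// `metrics` are assumed locals; illustrative only, not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let diff = StateDiff::from_diff(&old_state, &new_state);
    /// // Applying the diff to (a copy of) `old_state` must reproduce `new_state`.
    /// old_state.apply_diff(metrics, diff)?;
    /// assert_eq!(old_state, new_state);
    /// ```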
    pub fn from_diff(from: &State<T>, to: &State<T>) -> Self {
        // Deconstruct from and to so we get a compile failure if new
        // fields are added.
        let State {
            applier_version: _,
            shard_id: from_shard_id,
            seqno: from_seqno,
            hostname: from_hostname,
            walltime_ms: _, // Intentionally unused
            collections:
                StateCollections {
                    last_gc_req: from_last_gc_req,
                    rollups: from_rollups,
                    active_rollup: from_active_rollup,
                    active_gc: from_active_gc,
                    leased_readers: from_leased_readers,
                    critical_readers: from_critical_readers,
                    writers: from_writers,
                    schemas: from_schemas,
                    trace: from_trace,
                },
        } = from;
        let State {
            applier_version: to_applier_version,
            shard_id: to_shard_id,
            seqno: to_seqno,
            walltime_ms: to_walltime_ms,
            hostname: to_hostname,
            collections:
                StateCollections {
                    last_gc_req: to_last_gc_req,
                    rollups: to_rollups,
                    active_rollup: to_active_rollup,
                    active_gc: to_active_gc,
                    leased_readers: to_leased_readers,
                    critical_readers: to_critical_readers,
                    writers: to_writers,
                    schemas: to_schemas,
                    trace: to_trace,
                },
        } = to;
        assert_eq!(from_shard_id, to_shard_id);

        let (_, latest_rollup) = to.latest_rollup();
        let mut diffs = Self::new(
            to_applier_version.clone(),
            *from_seqno,
            *to_seqno,
            *to_walltime_ms,
            latest_rollup.key.clone(),
        );
        diff_field_single(from_hostname, to_hostname, &mut diffs.hostname);
        diff_field_single(from_last_gc_req, to_last_gc_req, &mut diffs.last_gc_req);
        diff_field_sorted_iter(
            from_active_rollup.iter().map(|r| (&(), r)),
            to_active_rollup.iter().map(|r| (&(), r)),
            &mut diffs.active_rollup,
        );
        diff_field_sorted_iter(
            from_active_gc.iter().map(|g| (&(), g)),
            to_active_gc.iter().map(|g| (&(), g)),
            &mut diffs.active_gc,
        );
        diff_field_sorted_iter(from_rollups.iter(), to_rollups, &mut diffs.rollups);
        diff_field_sorted_iter(
            from_leased_readers.iter(),
            to_leased_readers,
            &mut diffs.leased_readers,
        );
        diff_field_sorted_iter(
            from_critical_readers.iter(),
            to_critical_readers,
            &mut diffs.critical_readers,
        );
        diff_field_sorted_iter(from_writers.iter(), to_writers, &mut diffs.writers);
        diff_field_sorted_iter(from_schemas.iter(), to_schemas, &mut diffs.schemas);
        diff_field_single(from_trace.since(), to_trace.since(), &mut diffs.since);

        let from_flat = from_trace.flatten();
        let to_flat = to_trace.flatten();
        diff_field_sorted_iter(
            from_flat.legacy_batches.iter().map(|(k, v)| (&**k, v)),
            to_flat.legacy_batches.iter().map(|(k, v)| (&**k, v)),
            &mut diffs.legacy_batches,
        );
        diff_field_sorted_iter(
            from_flat.hollow_batches.iter(),
            to_flat.hollow_batches.iter(),
            &mut diffs.hollow_batches,
        );
        diff_field_sorted_iter(
            from_flat.spine_batches.iter(),
            to_flat.spine_batches.iter(),
            &mut diffs.spine_batches,
        );
        diff_field_sorted_iter(
            from_flat.merges.iter(),
            to_flat.merges.iter(),
            &mut diffs.merges,
        );
        diffs
    }

    pub(crate) fn blob_inserts(&self) -> impl Iterator<Item = HollowBlobRef<T>> {
        let batches = self
            .referenced_batches()
            .filter_map(|spine_diff| match spine_diff {
                Insert(b) | Update(_, b) => Some(HollowBlobRef::Batch(b)),
                Delete(_) => None, // No-op
            });
        let rollups = self
            .rollups
            .iter()
            .filter_map(|rollups_diff| match &rollups_diff.val {
                StateFieldValDiff::Insert(x) | StateFieldValDiff::Update(_, x) => {
                    Some(HollowBlobRef::Rollup(x))
                }
                StateFieldValDiff::Delete(_) => None, // No-op
            });
        batches.chain(rollups)
    }

    pub(crate) fn blob_deletes(&self) -> impl Iterator<Item = HollowBlobRef<T>> {
        let batches = self
            .referenced_batches()
            .filter_map(|spine_diff| match spine_diff {
                Insert(_) => None,
                Update(a, _) | Delete(a) => Some(HollowBlobRef::Batch(a)),
            });
        let rollups = self
            .rollups
            .iter()
            .filter_map(|rollups_diff| match &rollups_diff.val {
                Insert(_) => None,
                Update(a, _) | Delete(a) => Some(HollowBlobRef::Rollup(a)),
            });
        batches.chain(rollups)
    }

    #[cfg(any(test, debug_assertions))]
    #[allow(dead_code)]
    pub fn validate_roundtrip<K, V, D>(
        metrics: &Metrics,
        from_state: &crate::internal::state::TypedState<K, V, T, D>,
        diff: &Self,
        to_state: &crate::internal::state::TypedState<K, V, T, D>,
    ) -> Result<(), String>
    where
        K: mz_persist_types::Codec + std::fmt::Debug,
        V: mz_persist_types::Codec + std::fmt::Debug,
        D: differential_dataflow::difference::Semigroup + Codec64,
    {
        use mz_proto::RustType;
        use prost::Message;

        use crate::internal::state::ProtoStateDiff;

        let mut roundtrip_state = from_state.clone(
            from_state.applier_version.clone(),
            from_state.hostname.clone(),
        );
        roundtrip_state.apply_diff(metrics, diff.clone())?;

        if &roundtrip_state != to_state {
            // The weird spacing in this format string is so they all line up
            // when printed out.
            return Err(format!(
                "state didn't roundtrip\n  from_state {:?}\n  to_state   {:?}\n  rt_state   {:?}\n  diff       {:?}\n",
                from_state, to_state, roundtrip_state, diff
            ));
        }

        let encoded_diff = diff.into_proto().encode_to_vec();
        let roundtrip_diff = Self::from_proto(
            ProtoStateDiff::decode(encoded_diff.as_slice()).map_err(|err| err.to_string())?,
        )
        .map_err(|err| err.to_string())?;

        if &roundtrip_diff != diff {
            // The weird spacing in this format string is so they all line up
            // when printed out.
            return Err(format!(
                "diff didn't roundtrip\n  diff    {:?}\n  rt_diff {:?}",
                diff, roundtrip_diff
            ));
        }

        Ok(())
    }
}

impl<T: Timestamp + Lattice + Codec64> State<T> {
    pub fn apply_encoded_diffs<'a, I: IntoIterator<Item = &'a VersionedData>>(
        &mut self,
        cfg: &PersistConfig,
        metrics: &Metrics,
        diffs: I,
    ) {
        let mut state_seqno = self.seqno;
        let diffs = diffs.into_iter().filter_map(move |x| {
            if x.seqno != state_seqno.next() {
                // No-op.
                return None;
            }
            let data = x.data.clone();
            let diff = metrics
                .codecs
                .state_diff
                // Note: `x.data` is a `Bytes`, so cloning just increments a ref count
                .decode(|| StateDiff::decode(&cfg.build_version, x.data.clone()));
            assert_eq!(diff.seqno_from, state_seqno);
            state_seqno = diff.seqno_to;
            Some((diff, data))
        });
        self.apply_diffs(metrics, diffs);
    }
}

impl<T: Timestamp + Lattice + Codec64> State<T> {
    pub fn apply_diffs<I: IntoIterator<Item = (StateDiff<T>, Bytes)>>(
        &mut self,
        metrics: &Metrics,
        diffs: I,
    ) {
        for (diff, data) in diffs {
            // TODO: This could special-case batch apply for diffs where it's
            // more efficient (in particular, spine batches that hit the slow
            // path).
            match self.apply_diff(metrics, diff) {
                Ok(()) => {}
                Err(err) => {
                    // Having the full diff in the error message is critical for debugging any
                    // issues that may arise from diff application. We pass along the original
                    // Bytes it decoded from just so we can decode in this error path, while
                    // avoiding any extraneous clones in the expected Ok path.
                    let diff = StateDiff::<T>::decode(&self.applier_version, data);
                    panic!(
                        "state diff should apply cleanly: {} diff {:?} state {:?}",
                        err, diff, self
                    )
                }
            }
        }
    }

    // Intentionally not even pub(crate) because all callers should use
    // [Self::apply_diffs].
    pub(super) fn apply_diff(
        &mut self,
        metrics: &Metrics,
        diff: StateDiff<T>,
    ) -> Result<(), String> {
        // Deconstruct diff so we get a compile failure if new fields are added.
        let StateDiff {
            applier_version: diff_applier_version,
            seqno_from: diff_seqno_from,
            seqno_to: diff_seqno_to,
            walltime_ms: diff_walltime_ms,
            latest_rollup_key: _,
            rollups: diff_rollups,
            active_rollup: diff_active_rollup,
            active_gc: diff_active_gc,
            hostname: diff_hostname,
            last_gc_req: diff_last_gc_req,
            leased_readers: diff_leased_readers,
            critical_readers: diff_critical_readers,
            writers: diff_writers,
            schemas: diff_schemas,
            since: diff_since,
            legacy_batches: diff_legacy_batches,
            hollow_batches: diff_hollow_batches,
            spine_batches: diff_spine_batches,
            merges: diff_merges,
        } = diff;
        if self.seqno == diff_seqno_to {
            return Ok(());
        }
        if self.seqno != diff_seqno_from {
            return Err(format!(
                "could not apply diff {} -> {} to state {}",
                diff_seqno_from, diff_seqno_to, self.seqno
            ));
        }
        self.seqno = diff_seqno_to;
        self.applier_version = diff_applier_version;
        self.walltime_ms = diff_walltime_ms;
        force_apply_diffs_single(
            &self.shard_id,
            diff_seqno_to,
            "hostname",
            diff_hostname,
            &mut self.hostname,
            metrics,
        )?;

        // Deconstruct collections so we get a compile failure if new fields are
        // added.
        let StateCollections {
            last_gc_req,
            rollups,
            active_rollup,
            active_gc,
            leased_readers,
            critical_readers,
            writers,
            schemas,
            trace,
        } = &mut self.collections;

        apply_diffs_map("rollups", diff_rollups, rollups)?;
        apply_diffs_single("last_gc_req", diff_last_gc_req, last_gc_req)?;
        apply_diffs_single_option("active_rollup", diff_active_rollup, active_rollup)?;
        apply_diffs_single_option("active_gc", diff_active_gc, active_gc)?;
        apply_diffs_map("leased_readers", diff_leased_readers, leased_readers)?;
        apply_diffs_map("critical_readers", diff_critical_readers, critical_readers)?;
        apply_diffs_map("writers", diff_writers, writers)?;
        apply_diffs_map("schemas", diff_schemas, schemas)?;

        let structure_unchanged = diff_hollow_batches.is_empty()
            && diff_spine_batches.is_empty()
            && diff_merges.is_empty();
        let spine_unchanged =
            diff_since.is_empty() && diff_legacy_batches.is_empty() && structure_unchanged;

        if spine_unchanged {
            return Ok(());
        }

        let mut flat = if trace.roundtrip_structure {
            metrics.state.apply_spine_flattened.inc();
            let mut flat = trace.flatten();
            apply_diffs_single("since", diff_since, &mut flat.since)?;
            apply_diffs_map(
                "legacy_batches",
                diff_legacy_batches
                    .into_iter()
                    .map(|StateFieldDiff { key, val }| StateFieldDiff {
                        key: Arc::new(key),
                        val,
                    }),
                &mut flat.legacy_batches,
            )?;
            Some(flat)
        } else {
            for x in diff_since {
                match x.val {
                    Update(from, to) => {
                        if trace.since() != &from {
                            return Err(format!(
                                "since update didn't match: {:?} vs {:?}",
                                self.collections.trace.since(),
                                &from
                            ));
                        }
                        trace.downgrade_since(&to);
                    }
                    Insert(_) => return Err("cannot insert since field".to_string()),
                    Delete(_) => return Err("cannot delete since field".to_string()),
                }
            }
            if !diff_legacy_batches.is_empty() {
                apply_diffs_spine(metrics, diff_legacy_batches, trace)?;
                debug_assert_eq!(trace.validate(), Ok(()), "{:?}", trace);
            }
            None
        };

        if !structure_unchanged {
            let flat = flat.get_or_insert_with(|| trace.flatten());
            apply_diffs_map(
                "hollow_batches",
                diff_hollow_batches,
                &mut flat.hollow_batches,
            )?;
            apply_diffs_map("spine_batches", diff_spine_batches, &mut flat.spine_batches)?;
            apply_diffs_map("merges", diff_merges, &mut flat.merges)?;
        }

        if let Some(flat) = flat {
            *trace = Trace::unflatten(flat)?;
        }

        // There are various sanity checks that this method could run (e.g. that
        // since, upper, and seqno_since don't regress, or that
        // diff.latest_rollup == state.rollups.last()). Are they a good idea? On
        // one hand, I like sanity checks; on the other, one of the goals here is
        // to keep the apply logic as straightforward and unchanging as possible.
        Ok(())
    }
}

fn diff_field_single<T: PartialEq + Clone>(
    from: &T,
    to: &T,
    diffs: &mut Vec<StateFieldDiff<(), T>>,
) {
    // This could use the `diff_field_sorted_iter(once(from), once(to), diffs)`
    // general impl, but we just do the obvious thing.
    if from != to {
        diffs.push(StateFieldDiff {
            key: (),
            val: Update(from.clone(), to.clone()),
        })
    }
}

fn apply_diffs_single_option<X: PartialEq + Debug>(
    name: &str,
    diffs: Vec<StateFieldDiff<(), X>>,
    single: &mut Option<X>,
) -> Result<(), String> {
    for diff in diffs {
        apply_diff_single_option(name, diff, single)?;
    }
    Ok(())
}

fn apply_diff_single_option<X: PartialEq + Debug>(
    name: &str,
    diff: StateFieldDiff<(), X>,
    single: &mut Option<X>,
) -> Result<(), String> {
    match diff.val {
        Update(from, to) => {
            if single.as_ref() != Some(&from) {
                return Err(format!(
                    "{} update didn't match: {:?} vs {:?}",
                    name, single, &from
                ));
            }
            *single = Some(to)
        }
        Insert(to) => {
            if single.is_some() {
                return Err(format!("{} insert found existing value", name));
            }
            *single = Some(to)
        }
        Delete(from) => {
            if single.as_ref() != Some(&from) {
                return Err(format!(
                    "{} delete didn't match: {:?} vs {:?}",
                    name, single, &from
                ));
            }
            *single = None
        }
    }
    Ok(())
}

fn apply_diffs_single<X: PartialEq + Debug>(
    name: &str,
    diffs: Vec<StateFieldDiff<(), X>>,
    single: &mut X,
) -> Result<(), String> {
    for diff in diffs {
        apply_diff_single(name, diff, single)?;
    }
    Ok(())
}

fn apply_diff_single<X: PartialEq + Debug>(
    name: &str,
    diff: StateFieldDiff<(), X>,
    single: &mut X,
) -> Result<(), String> {
    match diff.val {
        Update(from, to) => {
            if single != &from {
                return Err(format!(
                    "{} update didn't match: {:?} vs {:?}",
                    name, single, &from
                ));
            }
            *single = to
        }
        Insert(_) => return Err(format!("cannot insert {} field", name)),
        Delete(_) => return Err(format!("cannot delete {} field", name)),
    }
    Ok(())
}

// A hack to force apply a diff, making `single` equal to
// the Update `to` value, ignoring a mismatch on `from`.
// Used to migrate forward after writing down incorrect
// diffs.
//
// TODO: delete this once `hostname` has zero mismatches
fn force_apply_diffs_single<X: PartialEq + Debug>(
    shard_id: &ShardId,
    seqno: SeqNo,
    name: &str,
    diffs: Vec<StateFieldDiff<(), X>>,
    single: &mut X,
    metrics: &Metrics,
) -> Result<(), String> {
    for diff in diffs {
        force_apply_diff_single(shard_id, seqno, name, diff, single, metrics)?;
    }
    Ok(())
}

fn force_apply_diff_single<X: PartialEq + Debug>(
    shard_id: &ShardId,
    seqno: SeqNo,
    name: &str,
    diff: StateFieldDiff<(), X>,
    single: &mut X,
    metrics: &Metrics,
) -> Result<(), String> {
    match diff.val {
        Update(from, to) => {
            if single != &from {
                debug!(
                    "{}: update didn't match: {:?} vs {:?}, continuing to force apply diff to {:?} for shard {} and seqno {}",
                    name, single, &from, &to, shard_id, seqno
                );
                metrics.state.force_apply_hostname.inc();
            }
            *single = to
        }
        Insert(_) => return Err(format!("cannot insert {} field", name)),
        Delete(_) => return Err(format!("cannot delete {} field", name)),
    }
    Ok(())
}

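/// Computes the diff between two sorted key/value iterators as a merge-join:
/// keys only in `from` become `Delete`s, keys only in `to` become `Insert`s,
/// and keys in both with unequal values become `Update`s.
///
/// A minimal sketch of the expected output (`BTreeMap`s are used purely for
/// illustration; `ignore`d rather than compiled as a doctest):
///
/// ```ignore
/// let from = BTreeMap::from([(1, "a"), (2, "b")]);
/// let to = BTreeMap::from([(2, "c"), (3, "d")]);
/// let mut diffs = Vec::new();
/// diff_field_sorted_iter(from.iter(), to.iter(), &mut diffs);
/// // diffs == [Delete("a") at key 1, Update("b", "c") at key 2, Insert("d") at key 3]
/// ```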
fn diff_field_sorted_iter<'a, K, V, IF, IT>(from: IF, to: IT, diffs: &mut Vec<StateFieldDiff<K, V>>)
where
    K: Ord + Clone + 'a,
    V: PartialEq + Clone + 'a,
    IF: IntoIterator<Item = (&'a K, &'a V)>,
    IT: IntoIterator<Item = (&'a K, &'a V)>,
{
    let (mut from, mut to) = (from.into_iter(), to.into_iter());
    let (mut f, mut t) = (from.next(), to.next());
    loop {
        match (f, t) {
            (None, None) => break,
            (Some((fk, fv)), Some((tk, tv))) => match fk.cmp(tk) {
                Ordering::Less => {
                    diffs.push(StateFieldDiff {
                        key: fk.clone(),
                        val: Delete(fv.clone()),
                    });
                    let f_next = from.next();
                    debug_assert!(f_next.as_ref().map_or(true, |(fk_next, _)| fk_next > &fk));
                    f = f_next;
                }
                Ordering::Greater => {
                    diffs.push(StateFieldDiff {
                        key: tk.clone(),
                        val: Insert(tv.clone()),
                    });
                    let t_next = to.next();
                    debug_assert!(t_next.as_ref().map_or(true, |(tk_next, _)| tk_next > &tk));
                    t = t_next;
                }
                Ordering::Equal => {
                    // TODO: regression test for this if, I missed it in the
                    // original impl :)
                    if fv != tv {
                        diffs.push(StateFieldDiff {
                            key: fk.clone(),
                            val: Update(fv.clone(), tv.clone()),
                        });
                    }
                    let f_next = from.next();
                    debug_assert!(f_next.as_ref().map_or(true, |(fk_next, _)| fk_next > &fk));
                    f = f_next;
                    let t_next = to.next();
                    debug_assert!(t_next.as_ref().map_or(true, |(tk_next, _)| tk_next > &tk));
                    t = t_next;
                }
            },
            (None, Some((tk, tv))) => {
                diffs.push(StateFieldDiff {
                    key: tk.clone(),
                    val: Insert(tv.clone()),
                });
                let t_next = to.next();
                debug_assert!(t_next.as_ref().map_or(true, |(tk_next, _)| tk_next > &tk));
                t = t_next;
            }
            (Some((fk, fv)), None) => {
                diffs.push(StateFieldDiff {
                    key: fk.clone(),
                    val: Delete(fv.clone()),
                });
                let f_next = from.next();
                debug_assert!(f_next.as_ref().map_or(true, |(fk_next, _)| fk_next > &fk));
                f = f_next;
            }
        }
    }
}

fn apply_diffs_map<K: Ord, V: PartialEq + Debug>(
    name: &str,
    diffs: impl IntoIterator<Item = StateFieldDiff<K, V>>,
    map: &mut BTreeMap<K, V>,
) -> Result<(), String> {
    for diff in diffs {
        apply_diff_map(name, diff, map)?;
    }
    Ok(())
}

// This might leave state in an invalid (umm) state when returning an error. The
// caller ultimately ends up panic'ing on error, but if that changes, we might
// want to revisit this.
fn apply_diff_map<K: Ord, V: PartialEq + Debug>(
    name: &str,
    diff: StateFieldDiff<K, V>,
    map: &mut BTreeMap<K, V>,
) -> Result<(), String> {
    match diff.val {
        Insert(to) => {
            let prev = map.insert(diff.key, to);
            if prev != None {
                return Err(format!("{} insert found existing value: {:?}", name, prev));
            }
        }
        Update(from, to) => {
            let prev = map.insert(diff.key, to);
            if prev.as_ref() != Some(&from) {
                return Err(format!(
                    "{} update didn't match: {:?} vs {:?}",
                    name,
                    prev,
                    Some(from),
                ));
            }
        }
        Delete(from) => {
            let prev = map.remove(&diff.key);
            if prev.as_ref() != Some(&from) {
                return Err(format!(
                    "{} delete didn't match: {:?} vs {:?}",
                    name,
                    prev,
                    Some(from),
                ));
            }
        }
    };
    Ok(())
}

// This might leave state in an invalid (umm) state when returning an error. The
// caller ultimately ends up panic'ing on error, but if that changes, we might
// want to revisit this.
fn apply_diffs_spine<T: Timestamp + Lattice>(
    metrics: &Metrics,
    mut diffs: Vec<StateFieldDiff<HollowBatch<T>, ()>>,
    trace: &mut Trace<T>,
) -> Result<(), String> {
    // Another special case: sniff out a newly inserted batch (one whose lower
    // lines up with the current upper) and handle that now. Then fall through
    // to the rest of the handling on whatever is left.
    if let Some(insert) = sniff_insert(&mut diffs, trace.upper()) {
        // Ignore merge_reqs because whichever process generated this diff is
        // assigned the work.
        let () = trace.push_batch_no_merge_reqs(insert);
        // If this insert was the only thing in diffs, then return now instead
        // of falling through to the "no diffs" case in the match so we can inc
        // the apply_spine_fast_path metric.
        if diffs.is_empty() {
            metrics.state.apply_spine_fast_path.inc();
            return Ok(());
        }
    }

    match &diffs[..] {
        // Fast-path: no diffs.
        [] => return Ok(()),

        // Fast-path: batch insert with both new and most recent batch empty.
        // Spine will happily merge these empty batches together without a call
        // out to compaction.
        [
            StateFieldDiff {
                key: del,
                val: StateFieldValDiff::Delete(()),
            },
            StateFieldDiff {
                key: ins,
                val: StateFieldValDiff::Insert(()),
            },
        ] => {
            if del.is_empty()
                && ins.is_empty()
                && del.desc.lower() == ins.desc.lower()
                && PartialOrder::less_than(del.desc.upper(), ins.desc.upper())
            {
                // Ignore merge_reqs because whichever process generated this diff is
                // assigned the work.
                let () = trace.push_batch_no_merge_reqs(HollowBatch::empty(Description::new(
                    del.desc.upper().clone(),
                    ins.desc.upper().clone(),
                    // `keys.len() == 0` for both `del` and `ins` means we don't
                    // have to think about what the compaction frontier is for
                    // these batches (nothing in them, so nothing could have
                    // been compacted).
                    Antichain::from_elem(T::minimum()),
                )));
                metrics.state.apply_spine_fast_path.inc();
                return Ok(());
            }
        }
        // Fall-through
        _ => {}
    }

    // Fast-path: compaction
    if let Some((_inputs, output)) = sniff_compaction(&diffs) {
        let res = FueledMergeRes { output };
        // We can't predict how spine will arrange the batches when it's
        // hydrated. This means that something that is maintaining a Spine
        // starting at some seqno may not exactly match something else
        // maintaining the same spine starting at a different seqno. (Plus,
        // maybe these aren't even on the same version of the code and we've
        // changed the spine logic.) Because apply_merge_res is strict,
        // we're not _guaranteed_ that we can apply a compaction response
        // that was generated elsewhere. Most of the time we can, though, so
        // count the good ones and fall back to the slow path below when we
        // can't.
        if trace.apply_merge_res(&res).applied() {
            // Maybe return the replaced batches from apply_merge_res and verify
            // that they match _inputs?
            metrics.state.apply_spine_fast_path.inc();
            return Ok(());
        }

        // Otherwise, try our lenient application of a compaction result.
        let mut batches = Vec::new();
        trace.map_batches(|b| batches.push(b.clone()));

        match apply_compaction_lenient(metrics, batches, &res.output) {
            Ok(batches) => {
                let mut new_trace = Trace::default();
                new_trace.roundtrip_structure = trace.roundtrip_structure;
                new_trace.downgrade_since(trace.since());
                for batch in batches {
                    // Ignore merge_reqs because whichever process generated
                    // this diff is assigned the work.
                    let () = new_trace.push_batch_no_merge_reqs(batch.clone());
                }
                *trace = new_trace;
                metrics.state.apply_spine_slow_path_lenient.inc();
                return Ok(());
            }
            Err(err) => {
                return Err(format!(
                    "lenient compaction result apply unexpectedly failed: {}",
                    err
                ));
            }
        }
    }

    // Something complicated is going on, so reconstruct the Trace from scratch.
    metrics.state.apply_spine_slow_path.inc();
    debug!(
        "apply_diffs_spine didn't hit a fast-path diffs={:?} trace={:?}",
        diffs, trace
    );

    let batches = {
        let mut batches = BTreeMap::new();
        trace.map_batches(|b| assert_none!(batches.insert(b.clone(), ())));
        apply_diffs_map("spine", diffs.clone(), &mut batches).map(|_ok| batches)
    };

    let batches = match batches {
        Ok(batches) => batches,
        Err(err) => {
            metrics
                .state
                .apply_spine_slow_path_with_reconstruction
                .inc();
            debug!(
                "apply_diffs_spines could not apply diffs directly to existing trace batches: {}. diffs={:?} trace={:?}",
                err, diffs, trace
            );
            // if we couldn't apply our diffs directly to our trace's batches, we can
            // try one more trick: reconstruct a new spine with our existing batches,
            // in an attempt to create different merges than we currently have. then,
            // we can try to apply our diffs on top of these new (potentially) merged
            // batches.
            let mut reconstructed_spine = Trace::default();
            reconstructed_spine.roundtrip_structure = trace.roundtrip_structure;
            trace.map_batches(|b| {
                // Ignore merge_reqs because whichever process generated this
                // diff is assigned the work.
                let () = reconstructed_spine.push_batch_no_merge_reqs(b.clone());
            });

            let mut batches = BTreeMap::new();
            reconstructed_spine.map_batches(|b| assert_none!(batches.insert(b.clone(), ())));
            apply_diffs_map("spine", diffs, &mut batches)?;
            batches
        }
    };

    let mut new_trace = Trace::default();
    new_trace.roundtrip_structure = trace.roundtrip_structure;
    new_trace.downgrade_since(trace.since());
    for (batch, ()) in batches {
        // Ignore merge_reqs because whichever process generated this diff is
        // assigned the work.
        let () = new_trace.push_batch_no_merge_reqs(batch);
    }
    *trace = new_trace;
    Ok(())
}

fn sniff_insert<T: Timestamp + Lattice>(
    diffs: &mut Vec<StateFieldDiff<HollowBatch<T>, ()>>,
    upper: &Antichain<T>,
) -> Option<HollowBatch<T>> {
    for idx in 0..diffs.len() {
        match &diffs[idx] {
            StateFieldDiff {
                key,
                val: StateFieldValDiff::Insert(()),
            } if key.desc.lower() == upper => return Some(diffs.remove(idx).key),
            _ => continue,
        }
    }
    None
}

// TODO: Instead of trying to sniff out a compaction from diffs, should we just
// be explicit?
fn sniff_compaction<'a, T: Timestamp + Lattice>(
    diffs: &'a [StateFieldDiff<HollowBatch<T>, ()>],
) -> Option<(Vec<&'a HollowBatch<T>>, HollowBatch<T>)> {
    // Compaction always produces exactly one output batch (with possibly many
    // parts), but we get one Insert for the whole batch.
    let mut inserts = diffs.iter().flat_map(|x| match x.val {
        StateFieldValDiff::Insert(()) => Some(&x.key),
        _ => None,
    });
    let compaction_output = match inserts.next() {
        Some(x) => x,
        None => return None,
    };
    if let Some(_) = inserts.next() {
        return None;
    }

    // Grab all deletes and sanity check that there are no updates.
    let mut compaction_inputs = Vec::with_capacity(diffs.len() - 1);
    for diff in diffs.iter() {
        match diff.val {
            StateFieldValDiff::Delete(()) => {
                compaction_inputs.push(&diff.key);
            }
            StateFieldValDiff::Insert(()) => {}
            StateFieldValDiff::Update((), ()) => {
                // Fall through to let the general case create the error
                // message.
                return None;
            }
        }
    }

    Some((compaction_inputs, compaction_output.clone()))
}

/// Apply a compaction diff that doesn't exactly line up with the set of
/// HollowBatches.
///
/// Because of the way Spine internally optimizes only _some_ empty batches
/// (immediately merges them in), we can end up in a situation where a
/// compaction res applied on another copy of state, but when we replay all of
/// the state diffs against a new Spine locally, it merges empty batches
/// differently in-mem and we can't exactly apply the compaction diff. Example:
///
/// - compact: [1,2),[2,3) -> [1,3)
/// - this spine: [0,2),[2,3) (0,1 is empty)
///
/// Ideally, we'd figure out a way to avoid this, but nothing immediately comes
/// to mind. In the meantime, force the application (otherwise the shard is
/// stuck and we can't do anything with it) by manually splitting the empty
/// batch back out. For the example above:
///
/// - [0,1),[1,3) (0,1 is empty)
///
/// This can only happen when the batch needing to be split is empty, so error
/// out if it isn't because that means something unexpected is going on.
///
/// TODO: This implementation is certainly not correct if T is actually only
/// partially ordered.
fn apply_compaction_lenient<'a, T: Timestamp + Lattice>(
    metrics: &Metrics,
    mut trace: Vec<HollowBatch<T>>,
    replacement: &'a HollowBatch<T>,
) -> Result<Vec<HollowBatch<T>>, String> {
    let mut overlapping_batches = Vec::new();
    trace.retain(|b| {
        let before_replacement = PartialOrder::less_equal(b.desc.upper(), replacement.desc.lower());
        let after_replacement = PartialOrder::less_equal(replacement.desc.upper(), b.desc.lower());
        let overlaps_replacement = !(before_replacement || after_replacement);
        if overlaps_replacement {
            overlapping_batches.push(b.clone());
            false
        } else {
            true
        }
    });

    {
        let first_overlapping_batch = match overlapping_batches.first() {
            Some(x) => x,
            None => return Err("replacement didn't overlap any batches".into()),
        };
        if PartialOrder::less_than(
            first_overlapping_batch.desc.lower(),
            replacement.desc.lower(),
        ) {
            if first_overlapping_batch.len > 0 {
                return Err(format!(
                    "overlapping batch was unexpectedly non-empty: {:?}",
                    first_overlapping_batch
                ));
            }
            let desc = Description::new(
                first_overlapping_batch.desc.lower().clone(),
                replacement.desc.lower().clone(),
                first_overlapping_batch.desc.since().clone(),
            );
            trace.push(HollowBatch::empty(desc));
            metrics.state.apply_spine_slow_path_lenient_adjustment.inc();
        }
    }

    {
        let last_overlapping_batch = match overlapping_batches.last() {
            Some(x) => x,
            None => return Err("replacement didn't overlap any batches".into()),
        };
        if PartialOrder::less_than(
            replacement.desc.upper(),
            last_overlapping_batch.desc.upper(),
        ) {
            if last_overlapping_batch.len > 0 {
                return Err(format!(
                    "overlapping batch was unexpectedly non-empty: {:?}",
                    last_overlapping_batch
                ));
            }
            let desc = Description::new(
                replacement.desc.upper().clone(),
                last_overlapping_batch.desc.upper().clone(),
                last_overlapping_batch.desc.since().clone(),
            );
            trace.push(HollowBatch::empty(desc));
            metrics.state.apply_spine_slow_path_lenient_adjustment.inc();
        }
    }
    trace.push(replacement.clone());

    // We just inserted stuff at the end, so re-sort them into place.
    trace.sort_by(|a, b| a.desc.lower().elements().cmp(b.desc.lower().elements()));

    // This impl is a touch complex, so sanity check our work.
    let mut expected_lower = &Antichain::from_elem(T::minimum());
    for b in trace.iter() {
        if b.desc.lower() != expected_lower {
            return Err(format!(
                "lower {:?} did not match expected {:?}: {:?}",
                b.desc.lower(),
                expected_lower,
                trace
            ));
        }
        expected_lower = b.desc.upper();
    }
    Ok(trace)
}

/// A type that facilitates the proto encoding of a [`ProtoStateFieldDiffs`].
///
/// [`ProtoStateFieldDiffs`] is a columnar encoding of [`StateFieldDiff`]s; see
/// its doc comment for more info. The underlying buffer for a [`ProtoStateFieldDiffs`]
/// is a [`Bytes`] struct, which is an immutable, shared, reference-counted
/// buffer of data. Using a [`Bytes`] struct is a very efficient way to manage data
/// because multiple [`Bytes`] can reference different parts of the same underlying
/// portion of memory. See its doc comment for more info.
///
/// A [`ProtoStateFieldDiffsWriter`] maintains a mutable, unique data buffer, i.e.
/// a [`BytesMut`], which we use when encoding a [`StateFieldDiff`]. When encoding
/// is finished, we convert it into a [`ProtoStateFieldDiffs`] by "freezing" the
/// underlying buffer, converting it into a [`Bytes`] struct, so it can be shared.
///
/// [`Bytes`]: bytes::Bytes
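///
/// A minimal usage sketch (the `ProtoStateField` variant and the pre-built
/// `proto_seqno`/`proto_rollup` messages are hypothetical; `ignore`d rather
/// than compiled as a doctest):
///
/// ```ignore
/// let mut writer = ProtoStateFieldDiffs::default().into_writer();
/// // One field and diff_type entry per diff, plus its encoded key/val data.
/// writer.push_field(ProtoStateField::Rollups);
/// writer.push_diff_type(ProtoStateFieldDiffType::Insert);
/// writer.encode_proto(&proto_seqno); // key
/// writer.encode_proto(&proto_rollup); // val (an Insert has exactly one)
/// let diffs: ProtoStateFieldDiffs = writer.into_proto();
/// assert_eq!(diffs.validate(), Ok(()));
/// ```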
1142#[derive(Debug)]
1143pub struct ProtoStateFieldDiffsWriter {
1144    data_buf: BytesMut,
1145    proto: ProtoStateFieldDiffs,
1146}
1147
1148impl ProtoStateFieldDiffsWriter {
1149    /// Record a [`ProtoStateField`] for our columnar encoding.
1150    pub fn push_field(&mut self, field: ProtoStateField) {
1151        self.proto.fields.push(i32::from(field));
1152    }
1153
1154    /// Record a [`ProtoStateFieldDiffType`] for our columnar encoding.
1155    pub fn push_diff_type(&mut self, diff_type: ProtoStateFieldDiffType) {
1156        self.proto.diff_types.push(i32::from(diff_type));
1157    }
1158
1159    /// Encode a message for our columnar encoding.
1160    pub fn encode_proto<M: prost::Message>(&mut self, msg: &M) {
1161        let len_before = self.data_buf.len();
1162        self.data_buf.reserve(msg.encoded_len());
1163
1164        // Note: we use `encode_raw` as opposed to `encode` because all `encode` does is
1165        // check to make sure there's enough bytes in the buffer to fit our message
1166        // which we know there are because we just reserved the space. When benchmarking
1167        // `encode_raw` does offer a slight performance improvement over `encode`.
1168        msg.encode_raw(&mut self.data_buf);
1169
1170        // Record exactly how many bytes were written.
1171        let written_len = self.data_buf.len() - len_before;
1172        self.proto.data_lens.push(u64::cast_from(written_len));
1173    }
1174
1175    pub fn into_proto(self) -> ProtoStateFieldDiffs {
1176        let ProtoStateFieldDiffsWriter {
1177            data_buf,
1178            mut proto,
1179        } = self;
1180
1181        // Assert we didn't write into the proto's data_bytes field
1182        assert!(proto.data_bytes.is_empty());
1183
1184        // Move our buffer into the proto
1185        let data_bytes = data_buf.freeze();
1186        proto.data_bytes = data_bytes;
1187
1188        proto
1189    }
1190}
1191
1192impl ProtoStateFieldDiffs {
1193    pub fn into_writer(mut self) -> ProtoStateFieldDiffsWriter {
1194        // Create a new buffer which we'll encode data into.
1195        let mut data_buf = BytesMut::with_capacity(self.data_bytes.len());
1196
1197        // Take our existing data, and copy it into our buffer.
1198        let existing_data = std::mem::take(&mut self.data_bytes);
1199        data_buf.extend(existing_data);
1200
1201        ProtoStateFieldDiffsWriter {
1202            data_buf,
1203            proto: self,
1204        }
1205    }
1206
1207    pub fn iter<'a>(&'a self) -> ProtoStateFieldDiffsIter<'a> {
1208        let len = self.fields.len();
1209        assert_eq!(self.diff_types.len(), len);
1210
1211        ProtoStateFieldDiffsIter {
1212            len,
1213            diff_idx: 0,
1214            data_idx: 0,
1215            data_offset: 0,
1216            diffs: self,
1217        }
1218    }
1219
1220    pub fn validate(&self) -> Result<(), String> {
1221        if self.fields.len() != self.diff_types.len() {
1222            return Err(format!(
1223                "fields {} and diff_types {} lengths disagree",
1224                self.fields.len(),
1225                self.diff_types.len()
1226            ));
1227        }
1228
1229        let mut expected_data_slices = 0;
1230        for diff_type in self.diff_types.iter() {
1231            // We expect one for the key.
1232            expected_data_slices += 1;
1233            // And 1 or 2 for val depending on the diff type.
1234            match ProtoStateFieldDiffType::try_from(*diff_type) {
1235                Ok(ProtoStateFieldDiffType::Insert) => expected_data_slices += 1,
1236                Ok(ProtoStateFieldDiffType::Update) => expected_data_slices += 2,
1237                Ok(ProtoStateFieldDiffType::Delete) => expected_data_slices += 1,
1238                Err(_) => return Err(format!("unknown diff_type {}", diff_type)),
1239            }
1240        }
1241        if expected_data_slices != self.data_lens.len() {
1242            return Err(format!(
1243                "expected {} data slices got {}",
1244                expected_data_slices,
1245                self.data_lens.len()
1246            ));
1247        }
1248
1249        let expected_data_bytes = usize::cast_from(self.data_lens.iter().copied().sum::<u64>());
1250        if expected_data_bytes != self.data_bytes.len() {
1251            return Err(format!(
1252                "expected {} data bytes got {}",
1253                expected_data_bytes,
1254                self.data_bytes.len()
1255            ));
1256        }
1257
1258        Ok(())
1259    }
1260}
1261
1262#[derive(Debug)]
1263pub struct ProtoStateFieldDiff<'a> {
1264    pub key: &'a [u8],
1265    pub diff_type: ProtoStateFieldDiffType,
1266    pub from: &'a [u8],
1267    pub to: &'a [u8],
1268}
1269
1270pub struct ProtoStateFieldDiffsIter<'a> {
1271    len: usize,
1272    diff_idx: usize,
1273    data_idx: usize,
1274    data_offset: usize,
1275    diffs: &'a ProtoStateFieldDiffs,
1276}
1277
1278impl<'a> Iterator for ProtoStateFieldDiffsIter<'a> {
1279    type Item = Result<(ProtoStateField, ProtoStateFieldDiff<'a>), TryFromProtoError>;
1280
1281    fn next(&mut self) -> Option<Self::Item> {
1282        if self.diff_idx >= self.len {
1283            return None;
1284        }
1285        let mut next_data = || {
1286            let start = self.data_offset;
1287            let end = start + usize::cast_from(self.diffs.data_lens[self.data_idx]);
1288            let data = &self.diffs.data_bytes[start..end];
1289            self.data_idx += 1;
1290            self.data_offset = end;
1291            data
1292        };
1293        let field = match ProtoStateField::try_from(self.diffs.fields[self.diff_idx]) {
1294            Ok(x) => x,
1295            Err(_) => {
1296                return Some(Err(TryFromProtoError::unknown_enum_variant(format!(
1297                    "ProtoStateField({})",
1298                    self.diffs.fields[self.diff_idx]
1299                ))));
1300            }
1301        };
1302        let diff_type =
1303            match ProtoStateFieldDiffType::try_from(self.diffs.diff_types[self.diff_idx]) {
1304                Ok(x) => x,
1305                Err(_) => {
1306                    return Some(Err(TryFromProtoError::unknown_enum_variant(format!(
1307                        "ProtoStateFieldDiffType({})",
1308                        self.diffs.diff_types[self.diff_idx]
1309                    ))));
1310                }
1311            };
1312        let key = next_data();
1313        let (from, to): (&[u8], &[u8]) = match diff_type {
1314            ProtoStateFieldDiffType::Insert => (&[], next_data()),
1315            ProtoStateFieldDiffType::Update => (next_data(), next_data()),
1316            ProtoStateFieldDiffType::Delete => (next_data(), &[]),
1317        };
1318        let diff = ProtoStateFieldDiff {
1319            key,
1320            diff_type,
1321            from,
1322            to,
1323        };
1324        self.diff_idx += 1;
1325        Some(Ok((field, diff)))
1326    }
1327}
1328
1329#[cfg(test)]
1330mod tests {
1331    use semver::Version;
1332    use std::ops::ControlFlow::Continue;
1333
1334    use crate::internal::paths::{PartId, PartialBatchKey, RollupId, WriterKey};
1335    use mz_ore::metrics::MetricsRegistry;
1336
1337    use crate::ShardId;
1338    use crate::internal::state::TypedState;
1339
1340    use super::*;
1341
1342    /// Model a situation where a "leader" is constantly making changes to its state, and a "follower"
1343    /// is applying those changes as diffs.
1344    #[mz_ore::test]
1345    #[cfg_attr(miri, ignore)] // too slow
1346    fn test_state_sync() {
1347        use proptest::prelude::*;
1348
1349        #[derive(Debug, Clone)]
1350        enum Action {
1351            /// Append a (non)empty batch to the shard that covers the given length of time.
1352            Append { empty: bool, time_delta: u64 },
1353            /// Apply the Nth compaction request we've received to the shard state.
1354            Compact { req: usize },
1355        }
1356
1357        let action_gen: BoxedStrategy<Action> = {
1358            prop::strategy::Union::new([
1359                (any::<bool>(), 1u64..10u64)
1360                    .prop_map(|(empty, time_delta)| Action::Append { empty, time_delta })
1361                    .boxed(),
1362                (0usize..10usize)
1363                    .prop_map(|req| Action::Compact { req })
1364                    .boxed(),
1365            ])
1366            .boxed()
1367        };
1368
1369        fn run(actions: Vec<(Action, bool)>, metrics: &Metrics) {
1370            let version = Version::new(0, 100, 0);
1371            let writer_key = WriterKey::Version(version.to_string());
1372            let id = ShardId::new();
1373            let hostname = "computer";
1374            let typed: TypedState<String, (), u64, i64> =
1375                TypedState::new(version, id, hostname.to_string(), 0);
1376            let mut leader = typed.state;
1377
1378            let seqno = SeqNo::minimum();
1379            let mut lower = 0u64;
1380            let mut merge_reqs = vec![];
1381
1382            leader.collections.rollups.insert(
1383                seqno,
1384                HollowRollup {
1385                    key: PartialRollupKey::new(seqno, &RollupId::new()),
1386                    encoded_size_bytes: None,
1387                },
1388            );
1389            leader.collections.trace.roundtrip_structure = false;
1390            let mut follower = leader.clone();
1391
1392            for (action, roundtrip_structure) in actions {
1393                // Apply the given action and the new roundtrip_structure setting and take a diff.
1394                let mut old_leader = leader.clone();
1395                match action {
1396                    Action::Append { empty, time_delta } => {
1397                        let upper = lower + time_delta;
1398                        let key = if empty {
1399                            None
1400                        } else {
1401                            let id = PartId::new();
1402                            Some(PartialBatchKey::new(&writer_key, &id))
1403                        };
1404
1405                        let keys = key.as_ref().map(|k| k.0.as_str());
1406                        let reqs = leader.collections.trace.push_batch(
1407                            crate::internal::state::tests::hollow(
1408                                lower,
1409                                upper,
1410                                keys.as_slice(),
1411                                if empty { 0 } else { 1 },
1412                            ),
1413                        );
1414                        merge_reqs.extend(reqs);
1415                        lower = upper;
1416                    }
1417                    Action::Compact { req } => {
1418                        if !merge_reqs.is_empty() {
1419                            let req = merge_reqs.remove(req.min(merge_reqs.len() - 1));
1420                            let len = req.inputs.iter().map(|p| p.batch.len).sum();
1421                            let parts = req
1422                                .inputs
1423                                .into_iter()
1424                                .flat_map(|p| p.batch.parts.clone())
1425                                .collect();
1426                            let output = HollowBatch::new_run(req.desc, parts, len);
1427                            leader
1428                                .collections
1429                                .trace
1430                                .apply_merge_res(&FueledMergeRes { output });
1431                        }
1432                    }
1433                }
1434                leader.collections.trace.roundtrip_structure = roundtrip_structure;
1435                leader.seqno.0 += 1;
1436                let diff = StateDiff::from_diff(&old_leader, &leader);
1437
1438                // Validate that the diff applies both to the previous leader state (also
1439                // checked in debug asserts) and to the follower, which is synced only via diffs.
1440                old_leader
1441                    .apply_diff(metrics, diff.clone())
1442                    .expect("diff applies to the old version of the leader state");
1443                follower
1444                    .apply_diff(metrics, diff.clone())
1445                    .expect("diff applies to the synced version of the follower state");
1446
1447                // TODO: once spine structure is roundtripped through diffs, assert that the follower
1448                // has the same batches etc. as the leader does.
1449            }
1450        }
1451
1452        let config = PersistConfig::new_for_tests();
1453        let metrics_registry = MetricsRegistry::new();
1454        let metrics: Metrics = Metrics::new(&config, &metrics_registry);
1455
1456        proptest!(|(actions in prop::collection::vec((action_gen, any::<bool>()), 1..20))| {
1457            run(actions, &metrics)
1458        })
1459    }
1460
1461    // Regression test for the apply_diffs_spine special case that sniffs out an
1462    // insert, applies it, and then lets the remaining diffs (if any) fall
1463    // through to the rest of the code. See database-issues#4431.
1464    #[mz_ore::test]
1465    fn regression_15493_sniff_insert() {
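        // Shorthand for an empty-parts hollow batch with bounds [lower, upper), since 0, and
        // the given logical len.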
1466        fn hb(lower: u64, upper: u64, len: usize) -> HollowBatch<u64> {
1467            HollowBatch::new_run(
1468                Description::new(
1469                    Antichain::from_elem(lower),
1470                    Antichain::from_elem(upper),
1471                    Antichain::from_elem(0),
1472                ),
1473                Vec::new(),
1474                len,
1475            )
1476        }
1477
1478        // The input that triggered this bug is essentially the pattern matched by
1479        // `apply_lenient` _plus_ an insert. In
1480        // apply_diffs_spine, we use `sniff_insert` to steal the insert out of
1481        // the diffs and fall back to the rest of the logic to handle the
1482        // remaining diffs.
1483        //
1484        // Concretely, something like (the numbers are truncated versions of the
1485        // actual bug posted in the issue):
1486        // - spine: [0][7094664]0, [7094664][7185234]100
1487        // - diffs: [0][6805359]0 del, [6805359][7083793]0 del, [0][7083793]0 ins,
1488        //   [7185234][7185859]20 ins
1489        //
1490        // Sniffing out the insert lets us handle [7185234,7185859), and then
1491        // apply_lenient handles splitting up [0,7094664) so we can apply the
1492        // [0,6805359)+[6805359,7083793)->[0,7083793) swap.
1493
1494        let batches_before = vec![hb(0, 7094664, 0), hb(7094664, 7185234, 100)];
1495
1496        let diffs = vec![
1497            StateFieldDiff {
1498                key: hb(0, 6805359, 0),
1499                val: StateFieldValDiff::Delete(()),
1500            },
1501            StateFieldDiff {
1502                key: hb(6805359, 7083793, 0),
1503                val: StateFieldValDiff::Delete(()),
1504            },
1505            StateFieldDiff {
1506                key: hb(0, 7083793, 0),
1507                val: StateFieldValDiff::Insert(()),
1508            },
1509            StateFieldDiff {
1510                key: hb(7185234, 7185859, 20),
1511                val: StateFieldValDiff::Insert(()),
1512            },
1513        ];
1514
1515        // Ideally this first batch would be two batches, [0,7083793) and [7083793,7094664),
1516        // because `apply_lenient` splits it out, but when `apply_lenient`
1517        // reconstructs the trace, Spine happens to (deterministically) collapse
1518        // them back together. The main value of this test is that the
1519        // `apply_diffs_spine` call below doesn't return an Err, so don't read too
1520        // much into the exact batch bounds; they're just a sanity check.
1521        let batches_after = vec![
1522            hb(0, 7094664, 0),
1523            hb(7094664, 7185234, 100),
1524            hb(7185234, 7185859, 20),
1525        ];
1526
1527        let cfg = PersistConfig::new_for_tests();
1528        let state = TypedState::<(), (), u64, i64>::new(
1529            cfg.build_version.clone(),
1530            ShardId::new(),
1531            cfg.hostname.clone(),
1532            (cfg.now)(),
1533        );
1534        let state = state.clone_apply(&cfg, &mut |_seqno, _cfg, state| {
1535            for b in batches_before.iter() {
1536                let _merge_reqs = state.trace.push_batch(b.clone());
1537            }
1538            Continue::<(), ()>(())
1539        });
1540        let mut state = match state {
1541            Continue((_, x)) => x,
1542            _ => unreachable!(),
1543        };
1544
1545        let metrics = Metrics::new(&PersistConfig::new_for_tests(), &MetricsRegistry::new());
1546        assert_eq!(
1547            apply_diffs_spine(&metrics, diffs, &mut state.collections.trace),
1548            Ok(())
1549        );
1550
1551        let mut actual = Vec::new();
1552        state
1553            .collections
1554            .trace
1555            .map_batches(|b| actual.push(b.clone()));
1556        assert_eq!(actual, batches_after);
1557    }
1558
1559    #[mz_ore::test]
1560    #[cfg_attr(miri, ignore)] // too slow
1561    fn apply_lenient() {
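        // Each `(lower, upper, since, len)` tuple describes a parts-free HollowBatch; `testcase`
        // applies `replacement` to the `spine` batches via `apply_compaction_lenient` and
        // asserts the result matches `expected`.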
1562        #[track_caller]
1563        fn testcase(
1564            replacement: (u64, u64, u64, usize),
1565            spine: &[(u64, u64, u64, usize)],
1566            expected: Result<&[(u64, u64, u64, usize)], &str>,
1567        ) {
1568            fn batch(x: &(u64, u64, u64, usize)) -> HollowBatch<u64> {
1569                let (lower, upper, since, len) = x;
1570                let desc = Description::new(
1571                    Antichain::from_elem(*lower),
1572                    Antichain::from_elem(*upper),
1573                    Antichain::from_elem(*since),
1574                );
1575                HollowBatch::new_run(desc, Vec::new(), *len)
1576            }
1577            let replacement = batch(&replacement);
1578            let batches = spine.iter().map(batch).collect::<Vec<_>>();
1579
1580            let metrics = Metrics::new(&PersistConfig::new_for_tests(), &MetricsRegistry::new());
1581            let actual = apply_compaction_lenient(&metrics, batches, &replacement);
1582            let expected = match expected {
1583                Ok(batches) => Ok(batches.iter().map(batch).collect::<Vec<_>>()),
1584                Err(err) => Err(err.to_owned()),
1585            };
1586            assert_eq!(actual, expected);
1587        }
1588
1589        // Exact swap of N batches
1590        testcase(
1591            (0, 3, 0, 100),
1592            &[(0, 1, 0, 0), (1, 2, 0, 0), (2, 3, 0, 0)],
1593            Ok(&[(0, 3, 0, 100)]),
1594        );
1595
1596        // Swap out the middle of a batch
1597        testcase(
1598            (1, 2, 0, 100),
1599            &[(0, 3, 0, 0)],
1600            Ok(&[(0, 1, 0, 0), (1, 2, 0, 100), (2, 3, 0, 0)]),
1601        );
1602
1603        // Split batch at replacement lower
1604        testcase(
1605            (2, 4, 0, 100),
1606            &[(0, 3, 0, 0), (3, 4, 0, 0)],
1607            Ok(&[(0, 2, 0, 0), (2, 4, 0, 100)]),
1608        );
1609
1610        // Err: split batch at replacement lower not empty
1611        testcase(
1612            (2, 4, 0, 100),
1613            &[(0, 3, 0, 1), (3, 4, 0, 0)],
1614            Err(
1615                "overlapping batch was unexpectedly non-empty: HollowBatch { desc: ([0], [3], [0]), parts: [], len: 1, runs: [], run_meta: [] }",
1616            ),
1617        );
1618
1619        // Split batch at replacement lower (untouched batch before the split one)
1620        testcase(
1621            (2, 4, 0, 100),
1622            &[(0, 1, 0, 0), (1, 3, 0, 0), (3, 4, 0, 0)],
1623            Ok(&[(0, 1, 0, 0), (1, 2, 0, 0), (2, 4, 0, 100)]),
1624        );
1625
1626        // Split batch at replacement lower (since is preserved)
1627        testcase(
1628            (2, 4, 0, 100),
1629            &[(0, 3, 200, 0), (3, 4, 0, 0)],
1630            Ok(&[(0, 2, 200, 0), (2, 4, 0, 100)]),
1631        );
1632
1633        // Split batch at replacement upper
1634        testcase(
1635            (0, 2, 0, 100),
1636            &[(0, 1, 0, 0), (1, 4, 0, 0)],
1637            Ok(&[(0, 2, 0, 100), (2, 4, 0, 0)]),
1638        );
1639
1640        // Err: split batch at replacement upper not empty
1641        testcase(
1642            (0, 2, 0, 100),
1643            &[(0, 1, 0, 0), (1, 4, 0, 1)],
1644            Err(
1645                "overlapping batch was unexpectedly non-empty: HollowBatch { desc: ([1], [4], [0]), parts: [], len: 1, runs: [], run_meta: [] }",
1646            ),
1647        );
1648
1649        // Split batch at replacement upper (untouched batch after the split one)
1650        testcase(
1651            (0, 2, 0, 100),
1652            &[(0, 1, 0, 0), (1, 3, 0, 0), (3, 4, 0, 0)],
1653            Ok(&[(0, 2, 0, 100), (2, 3, 0, 0), (3, 4, 0, 0)]),
1654        );
1655
1656        // Split batch at replacement upper (since is preserved)
1657        testcase(
1658            (0, 2, 0, 100),
1659            &[(0, 1, 0, 0), (1, 4, 200, 0)],
1660            Ok(&[(0, 2, 0, 100), (2, 4, 200, 0)]),
1661        );
1662
1663        // Split batch at replacement lower and upper
1664        testcase(
1665            (2, 6, 0, 100),
1666            &[(0, 3, 0, 0), (3, 5, 0, 0), (5, 8, 0, 0)],
1667            Ok(&[(0, 2, 0, 0), (2, 6, 0, 100), (6, 8, 0, 0)]),
1668        );
1669
1670        // Replacement doesn't overlap (after)
1671        testcase(
1672            (2, 3, 0, 100),
1673            &[(0, 1, 0, 0)],
1674            Err("replacement didn't overlap any batches"),
1675        );
1676
1677        // Replacement doesn't overlap (before, though this would never happen in practice)
1678        testcase(
1679            (2, 3, 0, 100),
1680            &[(4, 5, 0, 0)],
1681            Err("replacement didn't overlap any batches"),
1682        );
1683    }
1684}