1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! An `UpsertStateBackend` that starts in memory and spills to RocksDB
//! when the total size passes some threshold.

use std::sync::Arc;

use mz_ore::metrics::DeleteOnDropGauge;
use prometheus::core::AtomicU64;
use serde::{de::DeserializeOwned, Serialize};

use super::memory::InMemoryHashMap;
use super::rocksdb::RocksDB;
use super::types::{
    GetStats, MergeStats, MergeValue, PutStats, PutValue, StateValue, UpsertStateBackend,
    UpsertValueAndSize,
};
use super::UpsertKey;

/// The storage that currently holds the upsert state of an [`AutoSpillBackend`].
///
/// An `AutoSpillBackend` starts out `InMemory`. Its `multi_put` replaces this with
/// `RocksDb` once the in-memory size exceeds the configured threshold. Nothing in
/// this file ever switches it back to `InMemory`.
pub enum BackendType<O> {
    /// State held in an in-memory hash map (does not support merging).
    InMemory(InMemoryHashMap<O>),
    /// State held in RocksDB after spilling.
    RocksDb(RocksDB<O>),
}

/// An `UpsertStateBackend` that keeps state in memory until its size exceeds
/// `auto_spill_threshold_bytes`. At that point it drains the in-memory entries
/// into a freshly created RocksDB backend and uses RocksDB from then on.
pub struct AutoSpillBackend<O, F> {
    /// The currently active storage; starts as `BackendType::InMemory`.
    backend_type: BackendType<O>,
    /// The spill threshold. Once the in-memory map's `current_size()` exceeds it
    /// after a `multi_put`, the state is moved to RocksDB.
    auto_spill_threshold_bytes: usize,
    /// Gauge set to 0 on construction (in-memory) and to 1 once spilled to RocksDB.
    rocksdb_autospill_in_use: Arc<DeleteOnDropGauge<'static, AtomicU64, Vec<String>>>,
    /// Factory producing the RocksDB backend. It is `Some` until the spill, which
    /// `take`s and calls it; a second spill attempt would panic.
    rocksdb_init_fn: Option<F>,
}

impl<O, F, Fut> AutoSpillBackend<O, F>
where
    O: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
    F: FnOnce() -> Fut + 'static,
    Fut: std::future::Future<Output = RocksDB<O>>,
{
    /// Builds a backend that begins with an empty in-memory hash map.
    ///
    /// * `rocksdb_init_fn` - async factory for the RocksDB backend. It is stored
    ///   and only invoked when a `multi_put` first pushes the in-memory size past
    ///   `auto_spill_threshold_bytes`.
    /// * `auto_spill_threshold_bytes` - in-memory size above which state spills
    ///   to RocksDB.
    /// * `rocksdb_autospill_in_use` - gauge reporting which backend is active
    ///   (0 = in memory, 1 = RocksDB); it is reset to 0 here.
    pub(crate) fn new(
        rocksdb_init_fn: F,
        auto_spill_threshold_bytes: usize,
        rocksdb_autospill_in_use: Arc<DeleteOnDropGauge<'static, AtomicU64, Vec<String>>>,
    ) -> Self {
        // A fresh backend is always in memory, so the gauge starts at 0.
        rocksdb_autospill_in_use.set(0);
        let initial_backend = BackendType::InMemory(InMemoryHashMap::default());
        let pending_rocksdb_init = Some(rocksdb_init_fn);

        Self {
            auto_spill_threshold_bytes,
            rocksdb_autospill_in_use,
            backend_type: initial_backend,
            rocksdb_init_fn: pending_rocksdb_init,
        }
    }
}

#[async_trait::async_trait(?Send)]
impl<O, F, Fut> UpsertStateBackend<O> for AutoSpillBackend<O, F>
where
    O: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
    F: FnOnce() -> Fut + 'static,
    Fut: std::future::Future<Output = RocksDB<O>>,
{
    /// Whether merges may be issued against the currently active backend.
    ///
    /// The in-memory map never supports merging. RocksDB does if it was
    /// configured to. Because the backend can switch from memory to RocksDB,
    /// this may change from `false` to `true` over the lifetime of `self`.
    fn supports_merge(&self) -> bool {
        if let BackendType::RocksDb(db) = &self.backend_type {
            db.supports_merge()
        } else {
            false
        }
    }

    /// Applies `puts` to the active backend. If this leaves the in-memory map
    /// larger than `auto_spill_threshold_bytes`, it spills everything to RocksDB.
    ///
    /// During a spill, the returned `size_diff` is corrected by adding RocksDB's
    /// reported size change and subtracting the drained in-memory size. The two
    /// backends may size the same values differently.
    ///
    /// The backend never reverts to memory, even if the state later shrinks
    /// below the threshold. That case is considered rare and not worth the
    /// complexity.
    async fn multi_put<P>(&mut self, puts: P) -> Result<PutStats, anyhow::Error>
    where
        P: IntoIterator<Item = (UpsertKey, PutValue<StateValue<O>>)>,
    {
        // Already spilled: RocksDB handles everything directly.
        let memory = match &mut self.backend_type {
            BackendType::RocksDb(db) => return db.multi_put(puts).await,
            BackendType::InMemory(memory) => memory,
        };

        let mut stats = memory.multi_put(puts).await?;
        let current_bytes: usize = memory
            .current_size()
            .try_into()
            .expect("unexpected error while casting");
        if current_bytes <= self.auto_spill_threshold_bytes {
            return Ok(stats);
        }

        tracing::info!("spilling to disk for upsert");
        // The factory is consumed here, so a spill can only ever happen once.
        let init_rocksdb = self
            .rocksdb_init_fn
            .take()
            .expect("Can only initialize once");
        let mut disk = init_rocksdb().await;

        // Re-insert every drained entry as a fresh put with no prior value.
        let (drained_size, drained) = memory.drain();
        let disk_stats = disk
            .multi_put(drained.map(|(key, value)| {
                let put = PutValue {
                    value: Some(value),
                    previous_value_metadata: None,
                };
                (key, put)
            }))
            .await?;

        // RocksDB may account for value sizes differently than the in-memory map.
        stats.size_diff += disk_stats.size_diff;
        stats.size_diff -= drained_size;

        self.backend_type = BackendType::RocksDb(disk);
        // 1 signals that RocksDB is now the active backend.
        self.rocksdb_autospill_in_use.set(1);
        Ok(stats)
    }

    /// Applies `merges` to RocksDB.
    ///
    /// Returns an error while still in memory, because the in-memory map cannot merge.
    async fn multi_merge<M>(&mut self, merges: M) -> Result<MergeStats, anyhow::Error>
    where
        M: IntoIterator<Item = (UpsertKey, MergeValue<StateValue<O>>)>,
    {
        if let BackendType::RocksDb(db) = &mut self.backend_type {
            db.multi_merge(merges).await
        } else {
            anyhow::bail!("InMemoryHashMap does not support merging");
        }
    }

    /// Forwards the lookup of `gets` (with output slots `results_out`) to the
    /// active backend's `multi_get`.
    async fn multi_get<'r, G, R>(
        &mut self,
        gets: G,
        results_out: R,
    ) -> Result<GetStats, anyhow::Error>
    where
        G: IntoIterator<Item = UpsertKey>,
        R: IntoIterator<Item = &'r mut UpsertValueAndSize<O>>,
        O: 'r,
    {
        match &mut self.backend_type {
            BackendType::RocksDb(db) => db.multi_get(gets, results_out).await,
            BackendType::InMemory(memory) => memory.multi_get(gets, results_out).await,
        }
    }
}