1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use once_cell::sync::OnceCell;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Mutex, MutexGuard};

/// A data structure for persisting and sharing state between multiple clients.
///
/// Some state should be shared between multiple clients. For example, when creating multiple clients
/// for the same service, it's desirable to share a client rate limiter. This way, when one client
/// receives a throttling response, the other clients will be aware of it as well.
///
/// Whether clients share state is dependent on their partition key `K`. Going back to the client
/// rate limiter example, `K` would be a struct containing the name of the service as well as the
/// client's configured region, since receiving throttling responses in `us-east-1` shouldn't
/// throttle requests to the same service made in other regions.
///
/// Values stored in a `StaticPartitionMap` will be cloned whenever they are requested. Values must
/// be initialized before they can be retrieved, and the `StaticPartitionMap::get_or_init` method is
/// how you can ensure this.
///
/// Every operation locks a single internal `Mutex` for its duration, so (when `K` and `V` are
/// `Send`) a `StaticPartitionMap` can be stored in a `static` and shared across threads, as in the
/// example below. Because values are handed out as clones, `V` is typically a cheap handle to
/// shared, internally mutable state such as an `Arc<Mutex<_>>`.
///
/// # Example
///
/// ```
/// use std::sync::{Arc, Mutex};
/// use aws_smithy_runtime::static_partition_map::StaticPartitionMap;
///
/// // The shared state must be `Clone` and will be internally mutable. Deriving `Default` isn't
/// // necessary, but allows us to use the `StaticPartitionMap::get_or_init_default` method.
/// #[derive(Clone, Default)]
/// pub struct SomeSharedState {
///     inner: Arc<Mutex<Inner>>
/// }
///
/// #[derive(Default)]
/// struct Inner {
///     // Some shared state...
/// }
///
/// // `Clone`, `Hash`, and `Eq` are all required trait impls for partition keys
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// pub struct SharedStatePartition {
///     region: String,
///     service_name: String,
/// }
///
/// impl SharedStatePartition {
///     pub fn new(region: impl Into<String>, service_name: impl Into<String>) -> Self {
///         Self { region: region.into(), service_name: service_name.into() }
///     }
/// }
///
/// static SOME_SHARED_STATE: StaticPartitionMap<SharedStatePartition, SomeSharedState> = StaticPartitionMap::new();
///
/// struct Client {
///     shared_state: SomeSharedState,
/// }
///
/// impl Client {
///     pub fn new() -> Self {
///         let key = SharedStatePartition::new("us-east-1", "example_service_20230628");
///         Self {
///             // If the stored value implements `Default`, you can call the
///             // `StaticPartitionMap::get_or_init_default` convenience method.
///             shared_state: SOME_SHARED_STATE.get_or_init_default(key),
///         }
///     }
/// }
/// ```
#[derive(Debug)]
pub struct StaticPartitionMap<K, V> {
    // Created lazily on first use; `HashMap::new` is not a `const fn`, so the map itself cannot be
    // built inside the `const fn new` used for `static` initializers.
    inner: OnceCell<Mutex<HashMap<K, V>>>,
}

// Hand-written instead of `#[derive(Default)]`: the derive would require `K: Default` and
// `V: Default`, even though an empty map needs neither.
impl<K, V> Default for StaticPartitionMap<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<K, V> StaticPartitionMap<K, V> {
    /// Creates a new, empty `StaticPartitionMap`.
    ///
    /// This is a `const fn`, so it can be used to initialize a `static`.
    pub const fn new() -> Self {
        Self {
            inner: OnceCell::new(),
        }
    }
}

impl<K, V> StaticPartitionMap<K, V>
where
    K: Eq + Hash,
{
    /// Locks and returns the partition map, creating an empty map on first use.
    ///
    /// If an earlier holder of the lock panicked (for example, an `init` closure passed to
    /// `get_or_init`), the mutex is poisoned. This map usually lives in a `static`. Propagating
    /// the poison would therefore make every later caller in the process panic, so the guard is
    /// recovered instead.
    ///
    /// `or_insert_with` evaluates `init` before inserting, so a panicking `init` leaves no entry
    /// behind. The next `get_or_init` for that key simply initializes it again.
    fn get_or_init_inner(&self) -> MutexGuard<'_, HashMap<K, V>> {
        self.inner
            // At the very least, we'll always be storing the default state.
            .get_or_init(|| Mutex::new(HashMap::with_capacity(1)))
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}

impl<K, V> StaticPartitionMap<K, V>
where
    K: Eq + Hash,
    V: Clone,
{
    /// Gets a clone of the value for the given partition key.
    ///
    /// Returns `None` if no value has been initialized for `partition_key` yet.
    #[must_use]
    pub fn get(&self, partition_key: K) -> Option<V> {
        self.get_or_init_inner().get(&partition_key).cloned()
    }

    /// Gets a clone of the value for the given partition key, initializing it with `init` if it
    /// doesn't exist.
    ///
    /// `init` is only called when no value is stored for `partition_key`; an existing value is
    /// never replaced.
    ///
    /// `init` runs while the map's internal lock is held. It must not call back into this same
    /// `StaticPartitionMap`, because `std::sync::Mutex` is not reentrant. Relocking from the owning
    /// thread may deadlock or panic.
    #[must_use]
    pub fn get_or_init<F>(&self, partition_key: K, init: F) -> V
    where
        F: FnOnce() -> V,
    {
        let mut inner = self.get_or_init_inner();
        let v = inner.entry(partition_key).or_insert_with(init);
        v.clone()
    }
}

impl<K, V> StaticPartitionMap<K, V>
where
    K: Eq + Hash,
    V: Clone + Default,
{
    /// Gets a clone of the value for the given partition key, initializing it with
    /// `V::default()` if it doesn't exist.
    #[must_use]
    pub fn get_or_init_default(&self, partition_key: K) -> V {
        self.get_or_init(partition_key, V::default)
    }
}

#[cfg(test)]
mod tests {
    use super::StaticPartitionMap;

    #[test]
    fn test_keyed_partition_returns_same_value_for_same_key() {
        let kp = StaticPartitionMap::new();
        let first = kp.get_or_init("A", || "A".to_owned());
        assert_eq!("A".to_owned(), first);

        // The stored value must be reused; the second initializer must never run.
        let actual = kp.get_or_init("A", || -> String { unreachable!("already initialized") });
        let expected = "A".to_owned();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_keyed_partition_returns_different_value_for_different_key() {
        let kp = StaticPartitionMap::new();
        let _ = kp.get_or_init("A", || "A".to_owned());
        let actual = kp.get_or_init("B", || "B".to_owned());

        let expected = "B".to_owned();
        assert_eq!(expected, actual);

        let actual = kp.get("A").unwrap();
        let expected = "A".to_owned();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_get_returns_none_for_uninitialized_key() {
        let kp: StaticPartitionMap<&str, String> = StaticPartitionMap::new();
        // Nothing has been stored yet, not even the backing map.
        assert_eq!(None, kp.get("A"));

        let _ = kp.get_or_init("A", || "A".to_owned());
        // Initializing one partition must not make other partitions visible.
        assert_eq!(None, kp.get("B"));
    }

    #[test]
    fn test_get_or_init_default_stores_default_value() {
        let kp: StaticPartitionMap<&str, String> = StaticPartitionMap::new();
        assert_eq!(String::new(), kp.get_or_init_default("A"));

        // The default value is now stored, so a later initializer is ignored.
        let actual = kp.get_or_init("A", || "B".to_owned());
        assert_eq!(String::new(), actual);
        assert_eq!(Some(String::new()), kp.get("A"));
    }
}