azure_core/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// Pipeline execution context.
///
/// A type-keyed bag of values: at most one entity is stored per concrete type `E`, keyed by
/// `TypeId::of::<E>()`. Entities are stored behind an [`Arc`], so cloning a `Context` is shallow:
/// the clone shares the same entity instances as the original.
#[derive(Clone, Debug)]
pub struct Context {
    // Invariant: the value stored under `TypeId::of::<E>()` is always a type-erased `Arc<E>`.
    // Every insertion goes through `insert_or_replace`, which upholds this, so the downcasts in
    // `insert_or_replace`, `remove` and `get` always succeed.
    type_map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Default for Context {
    /// Equivalent to [`Context::new`]: an empty context.
    fn default() -> Self {
        Self::new()
    }
}

impl Context {
    /// Creates a new, empty `Context`.
    pub fn new() -> Self {
        Self {
            type_map: HashMap::new(),
        }
    }

    /// Inserts or replaces an entity in the type map.
    ///
    /// If an entity of the same type `E` was already present, it is displaced by the insert and
    /// returned to the caller. Otherwise `None` is returned.
    pub fn insert_or_replace<E>(&mut self, entity: E) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        // Values are always stored under the `TypeId` of their own concrete type. Whatever is
        // displaced from the `TypeId::of::<E>()` slot is therefore an `Arc<E>`, and the `expect`
        // below cannot fire.
        self.type_map
            .insert(TypeId::of::<E>(), Arc::new(entity))
            .map(|displaced| displaced.downcast().expect("failed to unwrap downcast"))
    }

    /// Inserts an entity in the type map.
    ///
    /// If an entity of the same type `E` is already present, the previously stored entity is
    /// silently discarded and replaced by `entity`. Use
    /// [`insert_or_replace`](Self::insert_or_replace) to get the displaced entity back instead.
    ///
    /// Returns a mutable reference to the same `Context`, so calls can be chained.
    pub fn insert<E>(&mut self, entity: E) -> &mut Self
    where
        E: Send + Sync + 'static,
    {
        // Delegate so the type-map invariant is maintained in a single place. The displaced
        // entity (if any) is intentionally discarded here.
        self.insert_or_replace(entity);
        self
    }

    /// Removes the entity of type `E` from the type map.
    ///
    /// If present, the entity is returned. Otherwise `None` is returned.
    pub fn remove<E>(&mut self) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        // Same invariant as in `insert_or_replace`: the value under `TypeId::of::<E>()` is an
        // `Arc<E>`, so the downcast cannot fail.
        self.type_map
            .remove(&TypeId::of::<E>())
            .map(|removed| removed.downcast().expect("failed to unwrap downcast"))
    }

    /// Returns a reference to the entity of type `E`, if one exists.
    ///
    /// If there is no entity with that type, `None` is returned instead.
    pub fn get<E>(&self) -> Option<&E>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .get(&TypeId::of::<E>())
            .and_then(|item| item.downcast_ref())
    }

    /// Returns the number of entities in the type map (at most one per type).
    pub fn len(&self) -> usize {
        self.type_map.len()
    }

    /// Returns `true` if the type map is empty, `false` otherwise.
    pub fn is_empty(&self) -> bool {
        self.type_map.is_empty()
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    #[test]
    fn insert_get_string() {
        let mut context = Context::new();
        context.insert_or_replace("pollo".to_string());
        assert_eq!(Some(&"pollo".to_string()), context.get());
    }

    #[test]
    fn insert_get_custom_structs() {
        #[derive(Debug, PartialEq, Eq)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq)]
        struct S2 {}

        let mut context = Context::new();
        context.insert_or_replace(S1 {});
        context.insert_or_replace(S2 {});

        assert_eq!(Some(Arc::new(S1 {})), context.insert_or_replace(S1 {}));
        assert_eq!(Some(Arc::new(S2 {})), context.insert_or_replace(S2 {}));

        assert_eq!(Some(&S1 {}), context.get());
        assert_eq!(Some(&S2 {}), context.get());
    }

    #[test]
    fn insert_fluent_syntax() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S2 {}

        let mut context = Context::new();

        context
            .insert("static str")
            .insert("a String".to_string())
            .insert(S1::default())
            .insert(S1::default()) // REPLACES the existing S1, so it does *not* increase `len()`
            .insert(S2::default());

        assert_eq!(4, context.len());
        assert_eq!(Some(&"static str"), context.get());
    }

    #[test]
    fn remove_and_is_empty() {
        let mut context = Context::new();
        assert!(context.is_empty());
        assert!(Context::default().is_empty());

        context.insert(7u32);
        assert!(!context.is_empty());
        assert_eq!(1, context.len());

        // removing hands back the stored entity; a second removal finds nothing.
        assert_eq!(Some(Arc::new(7u32)), context.remove::<u32>());
        assert_eq!(None, context.remove::<u32>());
        assert_eq!(None, context.get::<u32>());
        assert!(context.is_empty());
    }

    fn require_send_sync<T: Send + Sync>(_: &T) {}

    #[test]
    fn test_require_send_sync() {
        // this won't compile if Context as a whole is not Send + Sync
        require_send_sync(&Context::new());
    }

    #[test]
    fn mutability() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {
            num: u8,
        }
        let mut context = Context::new();
        context.insert_or_replace(Mutex::new(S1::default()));

        // the stored value is 0.
        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        // the Mutex provides interior mutability, so we can change the number to 42 through the
        // shared reference returned by `get`.
        context.get::<Mutex<S1>>().unwrap().lock().unwrap().num = 42;

        // now the number is 42.
        assert_eq!(42, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        // we replace the struct with a new one.
        let displaced = context
            .insert_or_replace(Mutex::new(S1::default()))
            .unwrap();

        // the displaced struct still holds 42 as number.
        assert_eq!(42, displaced.lock().unwrap().num);

        // the new struct holds 0 as number.
        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        context.insert_or_replace(Mutex::new(33u32));
        *context.get::<Mutex<u32>>().unwrap().lock().unwrap() = 42;
        assert_eq!(42, *context.get::<Mutex<u32>>().unwrap().lock().unwrap());
    }
}