1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
/// Execution context shared across the stages of a pipeline.
///
/// The context is a heterogeneous, type-keyed store: every concrete type `E` has at most one
/// entry, looked up by `TypeId::of::<E>()`. Entries are kept behind an `Arc`, so they can only
/// be read through a shared reference; wrap a value in a type with interior mutability (for
/// example `Mutex<T>`) if a stage needs to change it in place.
///
/// Cloning a `Context` copies the map itself but not the entities: both copies hold `Arc`s that
/// point at the very same values.
#[derive(Debug, Clone)]
pub struct Context {
    /// One type-erased, thread-shareable entity per concrete type, keyed by that type's `TypeId`.
    type_map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
1011impl Default for Context {
12fn default() -> Self {
13Self::new()
14 }
15}
1617impl Context {
18/// Creates a new, empty Context.
19pub fn new() -> Self {
20Self {
21 type_map: HashMap::new(),
22 }
23 }
2425/// Inserts or replaces an entity in the type map. If an entity with the same type was displaced
26 /// by the insert, it will be returned to the caller.
27pub fn insert_or_replace<E>(&mut self, entity: E) -> Option<Arc<E>>
28where
29E: Send + Sync + 'static,
30 {
31// we make sure that for every TypeId of E as key we ALWAYS retrieve an Option<Arc<E>>. That's why
32 // the `unwrap` below is safe.
33self.type_map
34 .insert(TypeId::of::<E>(), Arc::new(entity))
35 .map(|displaced| displaced.downcast().expect("failed to unwrap downcast"))
36 }
3738/// Inserts an entity in the type map. If the an entity with the same type signature is
39 /// already present it will be silently dropped. This function returns a mutable reference to
40 /// the same Context so it can be chained to itself.
41pub fn insert<E>(&mut self, entity: E) -> &mut Self
42where
43E: Send + Sync + 'static,
44 {
45self.type_map.insert(TypeId::of::<E>(), Arc::new(entity));
4647self
48}
4950/// Removes an entity from the type map. If present, the entity will be returned.
51pub fn remove<E>(&mut self) -> Option<Arc<E>>
52where
53E: Send + Sync + 'static,
54 {
55self.type_map
56 .remove(&TypeId::of::<E>())
57 .map(|removed| removed.downcast().expect("failed to unwrap downcast"))
58 }
5960/// Returns a reference of the entity of the specified type signature, if it exists.
61 ///
62 /// If there is no entity with the specific type signature, `None` is returned instead.
63pub fn get<E>(&self) -> Option<&E>
64where
65E: Send + Sync + 'static,
66 {
67self.type_map
68 .get(&TypeId::of::<E>())
69 .and_then(|item| item.downcast_ref())
70 }
7172/// Returns the number of entities in the type map.
73pub fn len(&self) -> usize {
74self.type_map.len()
75 }
7677/// Returns `true` if the type map is empty, `false` otherwise.
78pub fn is_empty(&self) -> bool {
79self.type_map.is_empty()
80 }
81}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    #[test]
    fn insert_get_string() {
        let mut context = Context::new();
        context.insert_or_replace("pollo".to_string());
        assert_eq!(Some(&"pollo".to_string()), context.get());
    }

    #[test]
    fn insert_get_custom_structs() {
        #[derive(Debug, PartialEq, Eq)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq)]
        struct S2 {}

        let mut context = Context::new();
        context.insert_or_replace(S1 {});
        context.insert_or_replace(S2 {});

        assert_eq!(Some(Arc::new(S1 {})), context.insert_or_replace(S1 {}));
        assert_eq!(Some(Arc::new(S2 {})), context.insert_or_replace(S2 {}));

        assert_eq!(Some(&S1 {}), context.get());
        assert_eq!(Some(&S2 {}), context.get());
    }

    #[test]
    fn insert_fluent_syntax() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S2 {}

        let mut context = Context::new();

        context
            .insert("static str")
            .insert("a String".to_string())
            .insert(S1::default())
            .insert(S1::default()) // notice we are REPLACING S1. This call will *not* increment the counter
            .insert(S2::default());

        assert_eq!(4, context.len());
        assert_eq!(Some(&"static str"), context.get());
    }

    #[test]
    fn insert_replaces_existing_entity() {
        let mut context = Context::new();
        context.insert(1u32).insert(2u32);

        // `insert` keeps the newest value; the previously stored one is discarded.
        assert_eq!(Some(&2u32), context.get::<u32>());
        assert_eq!(1, context.len());
    }

    #[test]
    fn remove_returns_entity_and_shrinks_map() {
        let mut context = Context::new();
        assert!(context.is_empty());
        assert_eq!(None, context.remove::<u32>());

        context.insert(7u32).insert("other");
        assert_eq!(2, context.len());

        assert_eq!(Some(Arc::new(7u32)), context.remove::<u32>());
        assert_eq!(None, context.get::<u32>());
        assert_eq!(None, context.remove::<u32>());

        // the unrelated entity is still there.
        assert_eq!(1, context.len());
        assert!(!context.is_empty());
    }

    #[test]
    fn clones_share_entities() {
        let mut context = Context::new();
        context.insert(Mutex::new(0u8));
        let copy = context.clone();

        // mutating through one context is visible through the clone: entries are Arc-shared.
        *context.get::<Mutex<u8>>().unwrap().lock().unwrap() = 9;
        assert_eq!(9, *copy.get::<Mutex<u8>>().unwrap().lock().unwrap());
    }

    fn require_send_sync<T: Send + Sync>(_: &T) {}

    #[test]
    fn test_require_send_sync() {
        // this won't compile if Context as a whole is not Send + Sync
        require_send_sync(&Context::new());
    }

    #[test]
    fn mutability() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {
            num: u8,
        }
        let mut context = Context::new();
        context.insert_or_replace(Mutex::new(S1::default()));

        // the stored value is 0.
        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        // we change the number to 42 in a thread safe manner.
        context.get::<Mutex<S1>>().unwrap().lock().unwrap().num = 42;

        // now the number is 42.
        assert_eq!(42, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        // we replace the struct with a new one.
        let displaced = context
            .insert_or_replace(Mutex::new(S1::default()))
            .unwrap();

        // the displaced struct still holds 42 as its number.
        assert_eq!(42, displaced.lock().unwrap().num);

        // the new struct has 0 as its number.
        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        context.insert_or_replace(Mutex::new(33u32));
        *context.get::<Mutex<u32>>().unwrap().lock().unwrap() = 42;
        assert_eq!(42, *context.get::<Mutex<u32>>().unwrap().lock().unwrap());
    }
}