1use starnix_sync::Mutex;
6use std::any::{Any, TypeId};
7use std::collections::BTreeMap;
8use std::marker::{Send, Sync};
9use std::ops::Deref;
10use std::sync::Arc;
11
/// One type-erased entry in an `Expando`.
///
/// The value lives behind an `Arc`, so every lookup can hand out another strong reference to
/// the same allocation instead of copying the value.
#[derive(Debug)]
struct ExpandoSlot {
    /// The stored property; its concrete type is recovered by downcasting.
    value: Arc<dyn Any + Sync + Send>,
}
19
20impl ExpandoSlot {
21 fn new(value: Arc<dyn Any + Send + Sync>) -> Self {
22 ExpandoSlot { value }
23 }
24
25 fn downcast<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
26 self.value.clone().downcast::<T>().ok()
27 }
28}
29
30#[derive(Debug, Default)]
39pub struct Expando {
40 properties: Mutex<BTreeMap<TypeId, ExpandoSlot>>,
41}
42
43impl Expando {
44 pub fn get<T: Any + Send + Sync + Default + 'static>(&self) -> Arc<T> {
49 let mut properties = self.properties.lock();
50 let type_id = TypeId::of::<T>();
51 let slot =
52 properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(Arc::new(T::default())));
53 assert_eq!(type_id, slot.value.deref().type_id());
54 slot.downcast().expect("downcast of expando slot was successful")
55 }
56
57 pub fn get_or_init<T: Any + Send + Sync + 'static>(&self, init: impl FnOnce() -> T) -> Arc<T> {
63 self.get_or_try_init::<T, ()>(|| Ok(init())).expect("infallible initializer")
64 }
65
66 pub fn get_or_try_init<T: Any + Send + Sync + 'static, E>(
72 &self,
73 try_init: impl FnOnce() -> Result<T, E>,
74 ) -> Result<Arc<T>, E> {
75 let type_id = TypeId::of::<T>();
76
77 if let Some(slot) = self.properties.lock().get(&type_id) {
80 assert_eq!(type_id, slot.value.deref().type_id());
81 return Ok(slot.downcast().expect("downcast of expando slot was successful"));
82 }
83
84 let newly_init = Arc::new(try_init()?);
86
87 let mut properties = self.properties.lock();
89 let slot = properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(newly_init));
90 assert_eq!(type_id, slot.value.deref().type_id());
91 Ok(slot.downcast().expect("downcast of expando slot was successful"))
92 }
93
94 pub fn peek<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
97 let properties = self.properties.lock();
98 let type_id = TypeId::of::<T>();
99 let slot = properties.get(&type_id)?;
100 assert_eq!(type_id, slot.value.deref().type_id());
101 Some(slot.downcast().expect("downcast of expando slot was successful"))
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Property type with interior mutability, used to check that repeated `get` calls hand
    /// back one shared instance rather than fresh defaults.
    #[derive(Debug, Default)]
    struct MyStruct {
        counter: Mutex<i32>,
    }

    #[test]
    fn basic_test() {
        let expando = Expando::default();

        let original = expando.get::<MyStruct>();
        assert_eq!(*original.counter.lock(), 0);
        {
            let mut count = original.counter.lock();
            *count += 1;
        }

        // The second lookup must see the increment made through the first handle.
        let fetched_again = expando.get::<MyStruct>();
        assert_eq!(*fetched_again.counter.lock(), 1);
    }

    #[test]
    fn user_initializer() {
        let expando = Expando::default();

        let hello = expando.get_or_init(|| String::from("hello"));
        assert_eq!(hello.as_str(), "hello");

        // Once a value exists, later initializers must be ignored.
        let ignored_world = expando.get_or_init(|| String::from("world"));
        assert_eq!(
            ignored_world.as_str(),
            "hello",
            "expando must have preserved value from original initializer"
        );
        assert_eq!(Arc::as_ptr(&hello), Arc::as_ptr(&ignored_world));
    }

    #[test]
    fn nested_user_initializer() {
        let expando = Expando::default();

        // The initializer itself reads another property from the same expando.
        let describe_counter = || expando.get::<u32>().to_string();

        let first = expando.get_or_init(describe_counter);
        assert_eq!(first.as_str(), "0");
        let second = expando.get_or_init(describe_counter);
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }

    #[test]
    fn failed_init_can_be_retried() {
        let expando = Expando::default();

        let failure = expando.get_or_try_init::<String, String>(|| Err(String::from("oops")));
        assert_eq!(failure.unwrap_err().as_str(), "oops");

        // A failed initializer stores nothing, so the next attempt runs its own initializer.
        let success = expando.get_or_try_init::<String, String>(|| Ok(String::from("hurray")));
        assert_eq!(success.unwrap().as_str(), "hurray");
    }

    #[test]
    fn peek_works() {
        let expando = Expando::default();
        assert_eq!(expando.peek::<String>(), None);

        let initialized = expando.get_or_init(|| String::from("hello"));
        let peeked = expando.peek::<String>().unwrap();
        assert_eq!(peeked.as_str(), "hello");
        assert_eq!(Arc::as_ptr(&initialized), Arc::as_ptr(&peeked));
    }
}