expando/
lib.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use starnix_sync::Mutex;
6use std::any::{Any, TypeId};
7use std::collections::BTreeMap;
8use std::marker::{Send, Sync};
9use std::ops::Deref;
10use std::sync::Arc;
11
/// One entry in an [`Expando`].
///
/// The stored value is type-erased as `Arc<dyn Any + Send + Sync>`; the concrete `Arc<T>` is
/// recovered with [`ExpandoSlot::downcast`].
#[derive(Debug)]
struct ExpandoSlot {
    value: Arc<dyn Any + Send + Sync>,
}

impl ExpandoSlot {
    /// Wraps an already type-erased value in a slot.
    fn new(value: Arc<dyn Any + Send + Sync>) -> Self {
        Self { value }
    }

    /// Returns a new strong reference to the stored value as an `Arc<T>`, or `None` when the
    /// stored value is not a `T`.
    fn downcast<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        let shared = Arc::clone(&self.value);
        match shared.downcast::<T>() {
            Ok(typed) => Some(typed),
            Err(_) => None,
        }
    }
}
29
30/// A lazy collection of values of every type.
31///
32/// An Expando contains a single instance of every type. The values are instantiated lazily
33/// when accessed. Useful for letting modules add their own state to context objects without
34/// requiring the context object itself to know about the types in every module.
35///
36/// Typically the type a module uses in the Expando will be private to that module, which lets
37/// the module know that no other code is accessing its slot on the expando.
38#[derive(Debug, Default)]
39pub struct Expando {
40    properties: Mutex<BTreeMap<TypeId, ExpandoSlot>>,
41}
42
43impl Expando {
44    /// Get the slot in the expando associated with the given type.
45    ///
46    /// The slot is added to the expando lazily but the same instance is returned every time the
47    /// expando is queried for the same type.
48    pub fn get<T: Any + Send + Sync + Default + 'static>(&self) -> Arc<T> {
49        let mut properties = self.properties.lock();
50        let type_id = TypeId::of::<T>();
51        let slot =
52            properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(Arc::new(T::default())));
53        assert_eq!(type_id, slot.value.deref().type_id());
54        slot.downcast().expect("downcast of expando slot was successful")
55    }
56
57    /// Get the slot in the expando associated with the given type, running `init` to initialize
58    /// the slot if needed.
59    ///
60    /// The slot is added to the expando lazily but the same instance is returned every time the
61    /// expando is queried for the same type.
62    pub fn get_or_init<T: Any + Send + Sync + 'static>(&self, init: impl FnOnce() -> T) -> Arc<T> {
63        self.get_or_try_init::<T, ()>(|| Ok(init())).expect("infallible initializer")
64    }
65
66    /// Get the slot in the expando associated with the given type, running `try_init` to initialize
67    /// the slot if needed. Returns an error only if `try_init` returns an error.
68    ///
69    /// The slot is added to the expando lazily but the same instance is returned every time the
70    /// expando is queried for the same type.
71    pub fn get_or_try_init<T: Any + Send + Sync + 'static, E>(
72        &self,
73        try_init: impl FnOnce() -> Result<T, E>,
74    ) -> Result<Arc<T>, E> {
75        let type_id = TypeId::of::<T>();
76
77        // Acquire the lock each time we want to look at the map so that user-provided initializer
78        // can use the expando too.
79        if let Some(slot) = self.properties.lock().get(&type_id) {
80            assert_eq!(type_id, slot.value.deref().type_id());
81            return Ok(slot.downcast().expect("downcast of expando slot was successful"));
82        }
83
84        // Initialize the new value without holding the lock.
85        let newly_init = Arc::new(try_init()?);
86
87        // Only insert the newly-initialized value if no other threads got there first.
88        let mut properties = self.properties.lock();
89        let slot = properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(newly_init));
90        assert_eq!(type_id, slot.value.deref().type_id());
91        Ok(slot.downcast().expect("downcast of expando slot was successful"))
92    }
93
94    /// Get the slot in the expando associated with the given type if it has previously been
95    /// initialized.
96    pub fn peek<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
97        let properties = self.properties.lock();
98        let type_id = TypeId::of::<T>();
99        let slot = properties.get(&type_id)?;
100        assert_eq!(type_id, slot.value.deref().type_id());
101        Some(slot.downcast().expect("downcast of expando slot was successful"))
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload with interior mutability, used to check that repeated lookups share one instance.
    #[derive(Debug, Default)]
    struct MyStruct {
        counter: Mutex<i32>,
    }

    #[test]
    fn basic_test() {
        let expando = Expando::default();

        let original = expando.get::<MyStruct>();
        assert_eq!(*original.counter.lock(), 0);
        {
            let mut count = original.counter.lock();
            *count += 1;
        }

        // A second lookup must observe the mutation made through the first handle.
        let again = expando.get::<MyStruct>();
        assert_eq!(*again.counter.lock(), 1);
    }

    #[test]
    fn user_initializer() {
        let expando = Expando::default();

        let hello = expando.get_or_init(|| String::from("hello"));
        assert_eq!(hello.as_str(), "hello");

        // The slot is already populated, so the result of this initializer must be ignored.
        let cached = expando.get_or_init(|| String::from("world"));
        assert_eq!(
            cached.as_str(),
            "hello",
            "expando must have preserved value from original initializer"
        );
        assert!(Arc::ptr_eq(&hello, &cached));
    }

    #[test]
    fn nested_user_initializer() {
        let expando = Expando::default();

        // The initializer itself queries the expando for another type.
        let init = || expando.get::<u32>().to_string();

        let first = expando.get_or_init(init);
        assert_eq!(first.as_str(), "0");

        let second = expando.get_or_init(init);
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn failed_init_can_be_retried() {
        let expando = Expando::default();

        let err = expando
            .get_or_try_init::<String, String>(|| Err(String::from("oops")))
            .unwrap_err();
        assert_eq!(err.as_str(), "oops");

        // A failed initializer leaves the slot empty, so a later attempt can fill it.
        let ok = expando
            .get_or_try_init::<String, String>(|| Ok(String::from("hurray")))
            .unwrap();
        assert_eq!(ok.as_str(), "hurray");
    }

    #[test]
    fn peek_works() {
        let expando = Expando::default();
        assert!(expando.peek::<String>().is_none());

        let initialized = expando.get_or_init(|| String::from("hello"));
        let peeked = expando.peek::<String>().expect("value was initialized above");
        assert_eq!(peeked.as_str(), "hello");
        assert!(Arc::ptr_eq(&initialized, &peeked));
    }
}