async_utils/async_once/
mod.rs1use async_lock::Mutex;
8use once_cell::sync::OnceCell;
9use std::future::Future;
10
11#[derive(Debug)]
13pub struct Once<T> {
14 mutex: Mutex<()>,
15 value: OnceCell<T>,
16}
17
18impl<T> Default for Once<T> {
19 fn default() -> Self {
20 Self { mutex: Mutex::new(()), value: OnceCell::new() }
21 }
22}
23
24impl<T> Once<T> {
25 pub fn new() -> Self {
27 Self { mutex: Mutex::new(()), value: OnceCell::new() }
28 }
29
30 pub async fn get_or_init<'a, F>(&'a self, fut: F) -> &'a T
32 where
33 F: Future<Output = T>,
34 {
35 if let Some(t) = self.value.get() {
36 t
37 } else {
38 let _mut = self.mutex.lock().await;
39 if let Some(t) = self.value.get() {
41 t
42 } else {
43 let t = fut.await;
44 self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
45 self.value.get().unwrap()
46 }
47 }
48 }
49
50 pub async fn get_or_try_init<'a, F, E>(&'a self, fut: F) -> Result<&'a T, E>
52 where
53 F: Future<Output = Result<T, E>>,
54 {
55 if let Some(t) = self.value.get() {
56 Ok(t)
57 } else {
58 let _mut = self.mutex.lock().await;
59 if let Some(t) = self.value.get() {
61 Ok(t)
62 } else {
63 let r = fut.await;
64 match r {
65 Ok(t) => {
66 self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
67 Ok(self.value.get().unwrap())
68 }
69 Err(e) => Err(e),
70 }
71 }
72 }
73 }
74}
75
#[cfg(test)]
mod test {
    use super::*;
    use futures_lite::future::block_on;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Each test builds its own `Once` and call counter instead of sharing statics. Tests may
    // run in any order and in parallel. With a shared cell, every test but the first would
    // find it already initialized and the counter already incremented.

    #[test]
    fn test_get_or_init() {
        let once: Once<bool> = Once::new();
        let calls = &AtomicUsize::new(0);

        let val = block_on(once.get_or_init(async {
            let _: usize = calls.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Already initialized: the second initializer must not run, and its value is ignored.
        let val = block_on(once.get_or_init(async {
            let _: usize = calls.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        // Same scenario as `test_get_or_init`, but the cell is created via `Default`.
        let once: Once<bool> = Once::default();
        let calls = &AtomicUsize::new(0);

        let val = block_on(once.get_or_init(async {
            let _: usize = calls.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        let val = block_on(once.get_or_init(async {
            let _: usize = calls.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_try_init() {
        let once: Once<bool> = Once::new();
        let calls = &AtomicUsize::new(0);
        // `calls` is a `&AtomicUsize` (Copy), so each returned future gets its own copy of
        // the reference and nothing borrows from the closure itself.
        let initializer = move || async move {
            let val = calls.fetch_add(1, Ordering::SeqCst);
            if val == 0 { Err(std::io::Error::other("first attempt fails")) } else { Ok(true) }
        };

        // A failing initializer leaves the cell empty...
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(val.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // ...so the next call runs its initializer again, which now succeeds.
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(*val.unwrap());
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        // Once set, the stored value is returned without running the initializer.
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(*val.unwrap());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}