async_utils/async_once/
mod.rs1use async_lock::Mutex;
8use once_cell::sync::OnceCell;
9
/// An asynchronous, write-once cell.
///
/// The value is produced by the first call to [`Once::get_or_init`] or
/// [`Once::get_or_try_init`] whose initializer completes successfully; every
/// later call returns a reference to that same stored value.
#[derive(Debug)]
pub struct Once<T> {
    // Serializes initializers: it is held for the whole time an initializer runs.
    mutex: Mutex<()>,
    // The stored value. Written only while `mutex` is held; read without locking.
    value: OnceCell<T>,
}
16
17impl<T> Default for Once<T> {
18 fn default() -> Self {
19 Self { mutex: Mutex::new(()), value: OnceCell::new() }
20 }
21}
22
23impl<T> Once<T> {
24 pub fn new() -> Self {
26 Self { mutex: Mutex::new(()), value: OnceCell::new() }
27 }
28
29 pub async fn get_or_init<'a, F>(&'a self, f: F) -> &'a T
31 where
32 F: AsyncFnOnce() -> T,
33 {
34 if let Some(t) = self.value.get() {
35 t
36 } else {
37 let _mut = self.mutex.lock().await;
38 if let Some(t) = self.value.get() {
40 t
41 } else {
42 let t = f().await;
43 self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
44 self.value.get().unwrap()
45 }
46 }
47 }
48
49 pub async fn get_or_try_init<'a, F, E>(&'a self, f: F) -> Result<&'a T, E>
51 where
52 F: AsyncFnOnce() -> Result<T, E>,
53 {
54 if let Some(t) = self.value.get() {
55 Ok(t)
56 } else {
57 let _mut = self.mutex.lock().await;
58 if let Some(t) = self.value.get() {
60 Ok(t)
61 } else {
62 let r = f().await;
63 match r {
64 Ok(t) => {
65 self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
66 Ok(self.value.get().unwrap())
67 }
68 Err(e) => Err(e),
69 }
70 }
71 }
72 }
73}
74
#[cfg(test)]
mod test {
    use super::*;
    use futures_lite::future::block_on;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Every test owns its own `Once` and counter. The harness runs tests in
    // parallel and in no fixed order, so sharing one global cell/counter made
    // each test's assertions depend on which other test ran first.
    //
    // Counters are function-local statics rather than locals, so the
    // initializer closures capture nothing. A capture-free closure is `Copy`,
    // which lets the same one be passed to several calls.

    #[test]
    fn test_get_or_init() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let once = Once::new();

        let val = block_on(once.get_or_init(async || {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // Already initialized: the second initializer must not run.
        let val = block_on(once.get_or_init(async || {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        // Construct through `Default` so that impl is actually exercised.
        let once: Once<bool> = Once::default();

        let val = block_on(once.get_or_init(async || {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // Already initialized: the second initializer must not run.
        let val = block_on(once.get_or_init(async || {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_try_init() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let once = Once::new();

        // Fails on the first attempt and succeeds on every later one.
        let initializer = async || {
            let attempt = COUNTER.fetch_add(1, Ordering::SeqCst);
            if attempt == 0 { Err(std::io::Error::other("first attempt fails")) } else { Ok(true) }
        };

        // A failed initializer leaves the cell empty...
        let val = block_on(once.get_or_try_init(initializer));
        assert!(val.is_err());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // ...so the next call retries and stores the value...
        let val = block_on(once.get_or_try_init(initializer));
        assert!(*val.unwrap());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);

        // ...after which the initializer is never run again.
        let val = block_on(once.get_or_try_init(initializer));
        assert!(*val.unwrap());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
    }
}