// async_utils/async_once/mod.rs
use async_lock::Mutex;
use once_cell::sync::OnceCell;
use std::convert::Infallible;
use std::future::Future;
10
11#[derive(Debug)]
13pub struct Once<T> {
14 mutex: Mutex<()>,
15 value: OnceCell<T>,
16}
17
18impl<T> Default for Once<T> {
19 fn default() -> Self {
20 Self { mutex: Mutex::new(()), value: OnceCell::new() }
21 }
22}
23
24impl<T> Once<T> {
25 pub fn new() -> Self {
27 Self { mutex: Mutex::new(()), value: OnceCell::new() }
28 }
29
30 pub async fn get_or_init<'a, F>(&'a self, fut: F) -> &'a T
32 where
33 F: Future<Output = T>,
34 {
35 if let Some(t) = self.value.get() {
36 t
37 } else {
38 let _mut = self.mutex.lock().await;
39 if let Some(t) = self.value.get() {
41 t
42 } else {
43 let t = fut.await;
44 self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
45 self.value.get().unwrap()
46 }
47 }
48 }
49
50 pub async fn get_or_try_init<'a, F, E>(&'a self, fut: F) -> Result<&'a T, E>
52 where
53 F: Future<Output = Result<T, E>>,
54 {
55 if let Some(t) = self.value.get() {
56 Ok(t)
57 } else {
58 let _mut = self.mutex.lock().await;
59 if let Some(t) = self.value.get() {
61 Ok(t)
62 } else {
63 let r = fut.await;
64 match r {
65 Ok(t) => {
66 self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
67 Ok(self.value.get().unwrap())
68 }
69 Err(e) => Err(e),
70 }
71 }
72 }
73 }
74}
75
#[cfg(test)]
mod test {
    use super::*;
    use futures_lite::future::block_on;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Checks that the first `get_or_init` on an empty `once` runs its initializer
    /// exactly once. Also checks that a second call returns the stored value without
    /// polling its own initializer.
    fn assert_initializes_once(once: &Once<bool>) {
        let counter = AtomicUsize::new(0);

        let val = block_on(once.get_or_init(async {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // The cell is already filled, so this initializer must not run; if it did,
        // both the counter and the returned value would change.
        let val = block_on(once.get_or_init(async {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init() {
        assert_initializes_once(&Once::new());
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        assert_initializes_once(&Once::default());
    }

    #[test]
    fn test_get_or_try_init() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let once: Once<bool> = Once::new();

        // Fails on the first attempt and succeeds on every later one.
        let initializer = || async {
            let attempt = COUNTER.fetch_add(1, Ordering::SeqCst);
            if attempt == 0 {
                Err(std::io::Error::other("first attempt fails"))
            } else {
                Ok(true)
            }
        };

        // A failed initializer leaves the cell empty...
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(val.is_err());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // ...so the next call runs its initializer again and stores the value.
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(*val.unwrap());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);

        // Once filled, further initializers are not run.
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(*val.unwrap());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
    }
}