async_utils/async_once/
mod.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Async wrapper around `once_cell::sync::OnceCell`, for one-time initialization from async code.
6
7use async_lock::Mutex;
8use once_cell::sync::OnceCell;
9use std::future::Future;
10
11/// Wrapper presenting an async interface to a OnceCell.
12#[derive(Debug)]
13pub struct Once<T> {
14    mutex: Mutex<()>,
15    value: OnceCell<T>,
16}
17
18impl<T> Default for Once<T> {
19    fn default() -> Self {
20        Self { mutex: Mutex::new(()), value: OnceCell::new() }
21    }
22}
23
24impl<T> Once<T> {
25    /// Constructor.
26    pub fn new() -> Self {
27        Self { mutex: Mutex::new(()), value: OnceCell::new() }
28    }
29
30    /// Async wrapper around OnceCell's `get_or_init`.
31    pub async fn get_or_init<'a, F>(&'a self, fut: F) -> &'a T
32    where
33        F: Future<Output = T>,
34    {
35        if let Some(t) = self.value.get() {
36            t
37        } else {
38            let _mut = self.mutex.lock().await;
39            // Someone raced us and just released the lock
40            if let Some(t) = self.value.get() {
41                t
42            } else {
43                let t = fut.await;
44                self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
45                self.value.get().unwrap()
46            }
47        }
48    }
49
50    /// Async wrapper around OnceCell's `get_or_try_init`.
51    pub async fn get_or_try_init<'a, F, E>(&'a self, fut: F) -> Result<&'a T, E>
52    where
53        F: Future<Output = Result<T, E>>,
54    {
55        if let Some(t) = self.value.get() {
56            Ok(t)
57        } else {
58            let _mut = self.mutex.lock().await;
59            // Someone raced us and just released the lock
60            if let Some(t) = self.value.get() {
61                Ok(t)
62            } else {
63                let r = fut.await;
64                match r {
65                    Ok(t) => {
66                        self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
67                        Ok(self.value.get().unwrap())
68                    }
69                    Err(e) => Err(e),
70                }
71            }
72        }
73    }
74}
75
#[cfg(test)]
mod test {
    use super::*;
    use futures_lite::future::{block_on, pending, poll_once, yield_now, zip};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Runs two initializers against a fresh `once` and checks that only the first one is used.
    fn assert_initialized_exactly_once(once: &Once<bool>) {
        let counter = AtomicUsize::new(0);

        let val = block_on(once.get_or_init(async {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // The value is already set, so this initializer is dropped without being polled.
        let val = block_on(once.get_or_init(async {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init() {
        assert_initialized_exactly_once(&Once::new());
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        assert_initialized_exactly_once(&Once::default());
    }

    /// Initializer for `test_get_or_try_init` that fails on its first invocation only.
    async fn fail_first_time(counter: &AtomicUsize) -> Result<bool, std::io::Error> {
        if counter.fetch_add(1, Ordering::SeqCst) == 0 {
            Err(std::io::Error::other("first attempt fails"))
        } else {
            Ok(true)
        }
    }

    #[test]
    fn test_get_or_try_init() {
        let once: Once<bool> = Once::new();
        let counter = AtomicUsize::new(0);

        let val = block_on(once.get_or_try_init(fail_first_time(&counter)));
        assert!(val.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // The initializer gets another chance to run because the first attempt failed.
        let val = block_on(once.get_or_try_init(fail_first_time(&counter)));
        assert!(*val.unwrap());
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        // The initializer never runs again...
        let val = block_on(once.get_or_try_init(fail_first_time(&counter)));
        assert!(*val.unwrap());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_concurrent_callers_share_one_initialization() {
        let once: Once<u32> = Once::new();
        let counter = AtomicUsize::new(0);

        // `zip` polls the first future first; it takes the lock and then yields inside its
        // initializer, so the second caller must wait for the lock and then find the value
        // already set (the "someone raced us" path) instead of running its own initializer.
        let (first, second) = block_on(zip(
            once.get_or_init(async {
                yield_now().await;
                let _: usize = counter.fetch_add(1, Ordering::SeqCst);
                1
            }),
            once.get_or_init(async {
                let _: usize = counter.fetch_add(1, Ordering::SeqCst);
                2
            }),
        ));

        assert_eq!(*first, 1);
        assert_eq!(*second, 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_cancelled_initializer_releases_lock() {
        let once: Once<bool> = Once::new();

        // Poll an initializer that never completes exactly once, then drop it while it holds
        // the lock. The cell must stay empty and the lock must be released; otherwise the
        // call below would never finish.
        assert!(block_on(poll_once(once.get_or_init(pending()))).is_none());

        let val = block_on(once.get_or_init(async { true }));
        assert!(*val);
    }
}