async_utils/async_once/
mod.rs

// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Exposes an async wrapper around `once_cell`'s `OnceCell` for use in async code.
6
7use async_lock::Mutex;
8use once_cell::sync::OnceCell;
9use std::future::Future;
10
/// Wrapper presenting an async interface to a [`OnceCell`].
///
/// Reading an already-initialized value goes straight to the inner cell without locking.
/// The async mutex is only taken on the initialization path, so at most one initializer
/// passed to a given `Once` is awaited at a time.
#[derive(Debug)]
pub struct Once<T> {
    // Serializes initializers; it protects no data of its own, hence `()`.
    mutex: Mutex<()>,
    // The stored value; only ever written while `mutex` is held.
    value: OnceCell<T>,
}
17
18impl<T> Default for Once<T> {
19    fn default() -> Self {
20        Self { mutex: Mutex::new(()), value: OnceCell::new() }
21    }
22}
23
24impl<T> Once<T> {
25    /// Constructor.
26    pub fn new() -> Self {
27        Self { mutex: Mutex::new(()), value: OnceCell::new() }
28    }
29
30    /// Async wrapper around OnceCell's `get_or_init`.
31    pub async fn get_or_init<'a, F>(&'a self, fut: F) -> &'a T
32    where
33        F: Future<Output = T>,
34    {
35        if let Some(t) = self.value.get() {
36            t
37        } else {
38            let _mut = self.mutex.lock().await;
39            // Someone raced us and just released the lock
40            if let Some(t) = self.value.get() {
41                t
42            } else {
43                let t = fut.await;
44                self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
45                self.value.get().unwrap()
46            }
47        }
48    }
49
50    /// Async wrapper around OnceCell's `get_or_try_init`.
51    pub async fn get_or_try_init<'a, F, E>(&'a self, fut: F) -> Result<&'a T, E>
52    where
53        F: Future<Output = Result<T, E>>,
54    {
55        if let Some(t) = self.value.get() {
56            Ok(t)
57        } else {
58            let _mut = self.mutex.lock().await;
59            // Someone raced us and just released the lock
60            if let Some(t) = self.value.get() {
61                Ok(t)
62            } else {
63                let r = fut.await;
64                match r {
65                    Ok(t) => {
66                        self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
67                        Ok(self.value.get().unwrap())
68                    }
69                    Err(e) => Err(e),
70                }
71            }
72        }
73    }
74}
75
#[cfg(test)]
mod test {
    //! Each test owns its own `Once` and counter. The test harness runs tests in parallel and
    //! in an unspecified order, so a `Once` shared between tests would be initialized by
    //! whichever test happened to run first, breaking the others' expectations.

    use super::*;
    use futures_lite::future::block_on;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Bumps `counter` and resolves to `value`, so tests can observe whether (and how often)
    /// an initializer was actually polled.
    async fn counting_init(counter: &AtomicUsize, value: bool) -> bool {
        let _: usize = counter.fetch_add(1, Ordering::SeqCst);
        value
    }

    /// Fails on its first invocation against `counter` and succeeds on every later one.
    async fn fail_first_init(counter: &AtomicUsize) -> Result<bool, std::io::Error> {
        if counter.fetch_add(1, Ordering::SeqCst) == 0 {
            Err(std::io::Error::other("first attempt fails"))
        } else {
            Ok(true)
        }
    }

    /// Checks that on a fresh `once` the first initializer runs and a second one never does.
    fn check_get_or_init(once: &Once<bool>) {
        let counter = AtomicUsize::new(0);

        let val = block_on(once.get_or_init(counting_init(&counter, true)));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Already initialized: the second initializer is dropped without being polled.
        let val = block_on(once.get_or_init(counting_init(&counter, false)));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init() {
        check_get_or_init(&Once::new());
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        check_get_or_init(&Once::default());
    }

    #[test]
    fn test_get_or_try_init() {
        let once: Once<bool> = Once::new();
        let counter = AtomicUsize::new(0);

        let val = block_on(once.get_or_try_init(fail_first_init(&counter)));
        assert!(val.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // The initializer gets another chance to run because the first attempt failed.
        let val = block_on(once.get_or_try_init(fail_first_init(&counter)));
        assert!(*val.unwrap());
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        // The initializer never runs again...
        let val = block_on(once.get_or_try_init(fail_first_init(&counter)));
        assert!(*val.unwrap());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
}