Skip to main content

async_utils/async_once/
mod.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Exposes the OnceCell crate for use in async code.
6
7use async_lock::Mutex;
8use once_cell::sync::OnceCell;
9
10/// Wrapper presenting an async interface to a OnceCell.
11#[derive(Debug)]
12pub struct Once<T> {
13    mutex: Mutex<()>,
14    value: OnceCell<T>,
15}
16
17impl<T> Default for Once<T> {
18    fn default() -> Self {
19        Self { mutex: Mutex::new(()), value: OnceCell::new() }
20    }
21}
22
23impl<T> Once<T> {
24    /// Constructor.
25    pub fn new() -> Self {
26        Self { mutex: Mutex::new(()), value: OnceCell::new() }
27    }
28
29    /// Async wrapper around OnceCell's `get_or_init`.
30    pub async fn get_or_init<'a, F>(&'a self, f: F) -> &'a T
31    where
32        F: AsyncFnOnce() -> T,
33    {
34        if let Some(t) = self.value.get() {
35            t
36        } else {
37            let _mut = self.mutex.lock().await;
38            // Someone raced us and just released the lock
39            if let Some(t) = self.value.get() {
40                t
41            } else {
42                let t = f().await;
43                self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
44                self.value.get().unwrap()
45            }
46        }
47    }
48
49    /// Async wrapper around OnceCell's `get_or_try_init`.
50    pub async fn get_or_try_init<'a, F, E>(&'a self, f: F) -> Result<&'a T, E>
51    where
52        F: AsyncFnOnce() -> Result<T, E>,
53    {
54        if let Some(t) = self.value.get() {
55            Ok(t)
56        } else {
57            let _mut = self.mutex.lock().await;
58            // Someone raced us and just released the lock
59            if let Some(t) = self.value.get() {
60                Ok(t)
61            } else {
62                let r = f().await;
63                match r {
64                    Ok(t) => {
65                        self.value.set(t).unwrap_or_else(|_| panic!("race in async-cell!"));
66                        Ok(self.value.get().unwrap())
67                    }
68                    Err(e) => Err(e),
69                }
70            }
71        }
72    }
73}
74
#[cfg(test)]
mod test {
    use super::*;
    use futures_lite::future::{block_on, yield_now, zip};
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Each test owns its `Once` and its call counter. A single static cell shared across tests
    // made results depend on execution order. Once any `get_or_init` test had stored `true`,
    // `test_get_or_try_init` could never observe its failing first attempt.

    #[test]
    fn test_get_or_init() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let once: Once<bool> = Once::new();

        let val = block_on(once.get_or_init(async || {
            let _: usize = COUNTER.fetch_add(1, Ordering::SeqCst);
            true
        }));

        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // Already initialized: the second initializer must not run.
        let val = block_on(once.get_or_init(async || {
            let _: usize = COUNTER.fetch_add(1, Ordering::SeqCst);
            false
        }));

        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        // Exercise the `Default` constructor rather than `Once::new`.
        let once: Once<bool> = Once::default();

        let val = block_on(once.get_or_init(async || {
            let _: usize = COUNTER.fetch_add(1, Ordering::SeqCst);
            true
        }));

        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        let val = block_on(once.get_or_init(async || {
            let _: usize = COUNTER.fetch_add(1, Ordering::SeqCst);
            false
        }));

        assert!(*val);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init_concurrent_callers_run_one_initializer() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let once: Once<bool> = Once::new();

        let initializer = async || {
            let _: usize = COUNTER.fetch_add(1, Ordering::SeqCst);
            // Suspend while holding the init lock so the other caller has to queue behind it
            // and then take the "someone raced us" path.
            yield_now().await;
            true
        };

        let (a, b) = block_on(zip(once.get_or_init(initializer), once.get_or_init(initializer)));

        assert!(*a);
        assert!(*b);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_try_init() {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let once: Once<bool> = Once::new();

        let initializer = async || {
            let val = COUNTER.fetch_add(1, Ordering::SeqCst);
            if val == 0 { Err(std::io::Error::other("first attempt fails")) } else { Ok(true) }
        };

        let val = block_on(once.get_or_try_init(initializer));

        assert!(val.is_err());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // The initializer gets another chance to run because the first attempt failed.
        let val = block_on(once.get_or_try_init(initializer));
        assert!(*val.unwrap());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);

        // The initializer never runs again...
        let val = block_on(once.get_or_try_init(initializer));
        assert!(*val.unwrap());
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
    }
}