futures_util/future/future/
shared.rs

1use crate::task::{waker_ref, ArcWake};
2use futures_core::future::{FusedFuture, Future};
3use futures_core::task::{Context, Poll, Waker};
4use slab::Slab;
5use std::cell::UnsafeCell;
6use std::fmt;
7use std::hash::Hasher;
8use std::pin::Pin;
9use std::ptr;
10use std::sync::atomic::AtomicUsize;
11use std::sync::atomic::Ordering::{Acquire, SeqCst};
12use std::sync::{Arc, Mutex, Weak};
13
14/// Future for the [`shared`](super::FutureExt::shared) method.
15#[must_use = "futures do nothing unless you `.await` or poll them"]
16pub struct Shared<Fut: Future> {
17    inner: Option<Arc<Inner<Fut>>>,
18    waker_key: usize,
19}
20
21struct Inner<Fut: Future> {
22    future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
23    notifier: Arc<Notifier>,
24}
25
26struct Notifier {
27    state: AtomicUsize,
28    wakers: Mutex<Option<Slab<Option<Waker>>>>,
29}
30
31/// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`.
32pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
33
34impl<Fut: Future> Clone for WeakShared<Fut> {
35    fn clone(&self) -> Self {
36        Self(self.0.clone())
37    }
38}
39
40impl<Fut: Future> fmt::Debug for Shared<Fut> {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("Shared")
43            .field("inner", &self.inner)
44            .field("waker_key", &self.waker_key)
45            .finish()
46    }
47}
48
49impl<Fut: Future> fmt::Debug for Inner<Fut> {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        f.debug_struct("Inner").finish()
52    }
53}
54
55impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("WeakShared").finish()
58    }
59}
60
/// Contents of the shared cell: the wrapped future until it resolves, then
/// the value it resolved to.
enum FutureOrOutput<Fut: Future> {
    /// The wrapped future, not yet completed.
    Future(Fut),
    /// The output produced by the wrapped future.
    Output(Fut::Output),
}
65
66unsafe impl<Fut> Send for Inner<Fut>
67where
68    Fut: Future + Send,
69    Fut::Output: Send + Sync,
70{
71}
72
73unsafe impl<Fut> Sync for Inner<Fut>
74where
75    Fut: Future + Send,
76    Fut::Output: Send + Sync,
77{
78}
79
// Values of `Notifier::state`.

/// The inner future has not been polled yet, or its last poll returned
/// `Pending`.
const IDLE: usize = 0;
/// Some handle is currently polling the inner future.
const POLLING: usize = 1;
/// The inner future finished; its output is stored in the shared cell.
const COMPLETE: usize = 2;
/// The inner future panicked while being polled.
const POISONED: usize = 3;

/// Sentinel `waker_key` meaning "no waker slot registered for this handle".
const NULL_WAKER_KEY: usize = usize::MAX;
86
87impl<Fut: Future> Shared<Fut> {
88    pub(super) fn new(future: Fut) -> Self {
89        let inner = Inner {
90            future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
91            notifier: Arc::new(Notifier {
92                state: AtomicUsize::new(IDLE),
93                wakers: Mutex::new(Some(Slab::new())),
94            }),
95        };
96
97        Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
98    }
99}
100
101impl<Fut> Shared<Fut>
102where
103    Fut: Future,
104{
105    /// Returns [`Some`] containing a reference to this [`Shared`]'s output if
106    /// it has already been computed by a clone or [`None`] if it hasn't been
107    /// computed yet or this [`Shared`] already returned its output from
108    /// [`poll`](Future::poll).
109    pub fn peek(&self) -> Option<&Fut::Output> {
110        if let Some(inner) = self.inner.as_ref() {
111            match inner.notifier.state.load(SeqCst) {
112                COMPLETE => unsafe { return Some(inner.output()) },
113                POISONED => panic!("inner future panicked during poll"),
114                _ => {}
115            }
116        }
117        None
118    }
119
120    /// Creates a new [`WeakShared`] for this [`Shared`].
121    ///
122    /// Returns [`None`] if it has already been polled to completion.
123    pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
124        if let Some(inner) = self.inner.as_ref() {
125            return Some(WeakShared(Arc::downgrade(inner)));
126        }
127        None
128    }
129
130    /// Gets the number of strong pointers to this allocation.
131    ///
132    /// Returns [`None`] if it has already been polled to completion.
133    ///
134    /// # Safety
135    ///
136    /// This method by itself is safe, but using it correctly requires extra care. Another thread
137    /// can change the strong count at any time, including potentially between calling this method
138    /// and acting on the result.
139    #[allow(clippy::unnecessary_safety_doc)]
140    pub fn strong_count(&self) -> Option<usize> {
141        self.inner.as_ref().map(|arc| Arc::strong_count(arc))
142    }
143
144    /// Gets the number of weak pointers to this allocation.
145    ///
146    /// Returns [`None`] if it has already been polled to completion.
147    ///
148    /// # Safety
149    ///
150    /// This method by itself is safe, but using it correctly requires extra care. Another thread
151    /// can change the weak count at any time, including potentially between calling this method
152    /// and acting on the result.
153    #[allow(clippy::unnecessary_safety_doc)]
154    pub fn weak_count(&self) -> Option<usize> {
155        self.inner.as_ref().map(|arc| Arc::weak_count(arc))
156    }
157
158    /// Hashes the internal state of this `Shared` in a way that's compatible with `ptr_eq`.
159    pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
160        match self.inner.as_ref() {
161            Some(arc) => {
162                state.write_u8(1);
163                ptr::hash(Arc::as_ptr(arc), state);
164            }
165            None => {
166                state.write_u8(0);
167            }
168        }
169    }
170
171    /// Returns `true` if the two `Shared`s point to the same future (in a vein similar to
172    /// `Arc::ptr_eq`).
173    ///
174    /// Returns `false` if either `Shared` has terminated.
175    pub fn ptr_eq(&self, rhs: &Self) -> bool {
176        let lhs = match self.inner.as_ref() {
177            Some(lhs) => lhs,
178            None => return false,
179        };
180        let rhs = match rhs.inner.as_ref() {
181            Some(rhs) => rhs,
182            None => return false,
183        };
184        Arc::ptr_eq(lhs, rhs)
185    }
186}
187
188impl<Fut> Inner<Fut>
189where
190    Fut: Future,
191{
192    /// Safety: callers must first ensure that `self.inner.state`
193    /// is `COMPLETE`
194    unsafe fn output(&self) -> &Fut::Output {
195        match &*self.future_or_output.get() {
196            FutureOrOutput::Output(ref item) => item,
197            FutureOrOutput::Future(_) => unreachable!(),
198        }
199    }
200}
201
202impl<Fut> Inner<Fut>
203where
204    Fut: Future,
205    Fut::Output: Clone,
206{
207    /// Registers the current task to receive a wakeup when we are awoken.
208    fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
209        let mut wakers_guard = self.notifier.wakers.lock().unwrap();
210
211        let wakers = match wakers_guard.as_mut() {
212            Some(wakers) => wakers,
213            None => return,
214        };
215
216        let new_waker = cx.waker();
217
218        if *waker_key == NULL_WAKER_KEY {
219            *waker_key = wakers.insert(Some(new_waker.clone()));
220        } else {
221            match wakers[*waker_key] {
222                Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
223                // Could use clone_from here, but Waker doesn't specialize it.
224                ref mut slot => *slot = Some(new_waker.clone()),
225            }
226        }
227        debug_assert!(*waker_key != NULL_WAKER_KEY);
228    }
229
230    /// Safety: callers must first ensure that `inner.state`
231    /// is `COMPLETE`
232    unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
233        match Arc::try_unwrap(self) {
234            Ok(inner) => match inner.future_or_output.into_inner() {
235                FutureOrOutput::Output(item) => item,
236                FutureOrOutput::Future(_) => unreachable!(),
237            },
238            Err(inner) => inner.output().clone(),
239        }
240    }
241}
242
243impl<Fut> FusedFuture for Shared<Fut>
244where
245    Fut: Future,
246    Fut::Output: Clone,
247{
248    fn is_terminated(&self) -> bool {
249        self.inner.is_none()
250    }
251}
252
253impl<Fut> Future for Shared<Fut>
254where
255    Fut: Future,
256    Fut::Output: Clone,
257{
258    type Output = Fut::Output;
259
260    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
261        let this = &mut *self;
262
263        let inner = this.inner.take().expect("Shared future polled again after completion");
264
265        // Fast path for when the wrapped future has already completed
266        if inner.notifier.state.load(Acquire) == COMPLETE {
267            // Safety: We're in the COMPLETE state
268            return unsafe { Poll::Ready(inner.take_or_clone_output()) };
269        }
270
271        inner.record_waker(&mut this.waker_key, cx);
272
273        match inner
274            .notifier
275            .state
276            .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
277            .unwrap_or_else(|x| x)
278        {
279            IDLE => {
280                // Lock acquired, fall through
281            }
282            POLLING => {
283                // Another task is currently polling, at this point we just want
284                // to ensure that the waker for this task is registered
285                this.inner = Some(inner);
286                return Poll::Pending;
287            }
288            COMPLETE => {
289                // Safety: We're in the COMPLETE state
290                return unsafe { Poll::Ready(inner.take_or_clone_output()) };
291            }
292            POISONED => panic!("inner future panicked during poll"),
293            _ => unreachable!(),
294        }
295
296        let waker = waker_ref(&inner.notifier);
297        let mut cx = Context::from_waker(&waker);
298
299        struct Reset<'a> {
300            state: &'a AtomicUsize,
301            did_not_panic: bool,
302        }
303
304        impl Drop for Reset<'_> {
305            fn drop(&mut self) {
306                if !self.did_not_panic {
307                    self.state.store(POISONED, SeqCst);
308                }
309            }
310        }
311
312        let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };
313
314        let output = {
315            let future = unsafe {
316                match &mut *inner.future_or_output.get() {
317                    FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
318                    _ => unreachable!(),
319                }
320            };
321
322            let poll_result = future.poll(&mut cx);
323            reset.did_not_panic = true;
324
325            match poll_result {
326                Poll::Pending => {
327                    if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
328                    {
329                        // Success
330                        drop(reset);
331                        this.inner = Some(inner);
332                        return Poll::Pending;
333                    } else {
334                        unreachable!()
335                    }
336                }
337                Poll::Ready(output) => output,
338            }
339        };
340
341        unsafe {
342            *inner.future_or_output.get() = FutureOrOutput::Output(output);
343        }
344
345        inner.notifier.state.store(COMPLETE, SeqCst);
346
347        // Wake all tasks and drop the slab
348        let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
349        let mut wakers = wakers_guard.take().unwrap();
350        for waker in wakers.drain().flatten() {
351            waker.wake();
352        }
353
354        drop(reset); // Make borrow checker happy
355        drop(wakers_guard);
356
357        // Safety: We're in the COMPLETE state
358        unsafe { Poll::Ready(inner.take_or_clone_output()) }
359    }
360}
361
362impl<Fut> Clone for Shared<Fut>
363where
364    Fut: Future,
365{
366    fn clone(&self) -> Self {
367        Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
368    }
369}
370
371impl<Fut> Drop for Shared<Fut>
372where
373    Fut: Future,
374{
375    fn drop(&mut self) {
376        if self.waker_key != NULL_WAKER_KEY {
377            if let Some(ref inner) = self.inner {
378                if let Ok(mut wakers) = inner.notifier.wakers.lock() {
379                    if let Some(wakers) = wakers.as_mut() {
380                        wakers.remove(self.waker_key);
381                    }
382                }
383            }
384        }
385    }
386}
387
388impl ArcWake for Notifier {
389    fn wake_by_ref(arc_self: &Arc<Self>) {
390        let wakers = &mut *arc_self.wakers.lock().unwrap();
391        if let Some(wakers) = wakers.as_mut() {
392            for (_key, opt_waker) in wakers {
393                if let Some(waker) = opt_waker.take() {
394                    waker.wake();
395                }
396            }
397        }
398    }
399}
400
401impl<Fut: Future> WeakShared<Fut> {
402    /// Attempts to upgrade this [`WeakShared`] into a [`Shared`].
403    ///
404    /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled
405    /// to completion.
406    pub fn upgrade(&self) -> Option<Shared<Fut>> {
407        Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
408    }
409}