// futures_util/future/future/shared.rs
use crate::task::{waker_ref, ArcWake};
use futures_core::future::{FusedFuture, Future};
use futures_core::task::{Context, Poll, Waker};
use slab::Slab;
use std::cell::UnsafeCell;
use std::fmt;
use std::hash::Hasher;
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, SeqCst};
use std::sync::{Arc, Mutex, Weak};
13
/// A cloneable handle to a future whose output is computed once and delivered to every clone.
///
/// Each handle that is polled to completion receives the output. The output is moved out when
/// this handle holds the only strong reference to the shared state. Otherwise it is cloned,
/// which is why the `Future` impl requires `Fut::Output: Clone`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Shared<Fut: Future> {
    // Shared state. `poll` takes it out on entry and puts it back only before returning
    // `Pending`, so `None` means this handle is finished.
    inner: Option<Arc<Inner<Fut>>>,
    // This handle's slot in the notifier's waker slab, or `NULL_WAKER_KEY` if no slot has been
    // allocated for it yet.
    waker_key: usize,
}

/// State shared by every clone of a `Shared`.
struct Inner<Fut: Future> {
    // The wrapped future. `poll` overwrites it in place with the output once the future
    // completes. Access is coordinated through `notifier.state`, not a lock.
    future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
    // State machine plus registered wakers. It is also used (via `ArcWake`) as the waker
    // handed to the wrapped future.
    notifier: Arc<Notifier>,
}

/// Coordination data: which phase the shared future is in, and which tasks to wake.
struct Notifier {
    // One of `IDLE`, `POLLING`, `COMPLETE`, `POISONED`.
    state: AtomicUsize,
    // One slot per handle that registered a waker. A slot becomes `None` once its waker has
    // been consumed by a wake-up. The whole slab is taken (leaving `None`) when the output is
    // published.
    wakers: Mutex<Option<Slab<Option<Waker>>>>,
}

/// A weak reference to the state of a `Shared`.
///
/// It does not keep the shared state alive. It can be turned back into a `Shared` with
/// `upgrade` while some strong handle still exists.
pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
33
34impl<Fut: Future> Clone for WeakShared<Fut> {
35 fn clone(&self) -> Self {
36 Self(self.0.clone())
37 }
38}
39
40impl<Fut: Future> fmt::Debug for Shared<Fut> {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("Shared")
43 .field("inner", &self.inner)
44 .field("waker_key", &self.waker_key)
45 .finish()
46 }
47}
48
49impl<Fut: Future> fmt::Debug for Inner<Fut> {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 f.debug_struct("Inner").finish()
52 }
53}
54
55impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("WeakShared").finish()
58 }
59}
60
/// Contents of the shared cell: the wrapped future until it completes, then its output.
enum FutureOrOutput<Fut: Future> {
    Future(Fut),
    Output(Fut::Output),
}

// SAFETY (reviewer summary of how this file uses `Inner`): `Inner` contains an `UnsafeCell`, so
// `Send`/`Sync` are implemented by hand.
// - Mutable access to the future happens only inside `Shared::poll` while the state is POLLING.
//   The IDLE -> POLLING compare-exchange grants that state to one handle at a time, so
//   `Fut: Send` is enough for the future.
// - After COMPLETE, the output is read through shared references from several handles (`peek`,
//   and the clone in `take_or_clone_output`), hence `Fut::Output: Sync`.
// - The last strong owner may move the output out, hence `Fut::Output: Send`.
unsafe impl<Fut> Send for Inner<Fut>
where
    Fut: Future + Send,
    Fut::Output: Send + Sync,
{
}

// SAFETY: see the rationale on the `Send` impl above; the same bounds apply.
unsafe impl<Fut> Sync for Inner<Fut>
where
    Fut: Future + Send,
    Fut::Output: Send + Sync,
{
}
79
// Values stored in `Notifier::state`.
// No handle is polling the wrapped future right now.
const IDLE: usize = 0;
// One handle won the IDLE -> POLLING transition and is polling the wrapped future.
const POLLING: usize = 1;
// The output has been written into the shared cell and may be read or taken.
const COMPLETE: usize = 2;
// The wrapped future panicked while being polled; every later poll/peek panics.
const POISONED: usize = 3;

// Sentinel `waker_key`: no slot in the waker slab has been allocated for this handle.
// `usize::MAX` replaces the legacy `usize::max_value()` method; the value is identical.
const NULL_WAKER_KEY: usize = usize::MAX;
86
87impl<Fut: Future> Shared<Fut> {
88 pub(super) fn new(future: Fut) -> Self {
89 let inner = Inner {
90 future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
91 notifier: Arc::new(Notifier {
92 state: AtomicUsize::new(IDLE),
93 wakers: Mutex::new(Some(Slab::new())),
94 }),
95 };
96
97 Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
98 }
99}
100
101impl<Fut> Shared<Fut>
102where
103 Fut: Future,
104{
105 pub fn peek(&self) -> Option<&Fut::Output> {
110 if let Some(inner) = self.inner.as_ref() {
111 match inner.notifier.state.load(SeqCst) {
112 COMPLETE => unsafe { return Some(inner.output()) },
113 POISONED => panic!("inner future panicked during poll"),
114 _ => {}
115 }
116 }
117 None
118 }
119
120 pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
124 if let Some(inner) = self.inner.as_ref() {
125 return Some(WeakShared(Arc::downgrade(inner)));
126 }
127 None
128 }
129
130 #[allow(clippy::unnecessary_safety_doc)]
140 pub fn strong_count(&self) -> Option<usize> {
141 self.inner.as_ref().map(|arc| Arc::strong_count(arc))
142 }
143
144 #[allow(clippy::unnecessary_safety_doc)]
154 pub fn weak_count(&self) -> Option<usize> {
155 self.inner.as_ref().map(|arc| Arc::weak_count(arc))
156 }
157
158 pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
160 match self.inner.as_ref() {
161 Some(arc) => {
162 state.write_u8(1);
163 ptr::hash(Arc::as_ptr(arc), state);
164 }
165 None => {
166 state.write_u8(0);
167 }
168 }
169 }
170
171 pub fn ptr_eq(&self, rhs: &Self) -> bool {
176 let lhs = match self.inner.as_ref() {
177 Some(lhs) => lhs,
178 None => return false,
179 };
180 let rhs = match rhs.inner.as_ref() {
181 Some(rhs) => rhs,
182 None => return false,
183 };
184 Arc::ptr_eq(lhs, rhs)
185 }
186}
187
188impl<Fut> Inner<Fut>
189where
190 Fut: Future,
191{
192 unsafe fn output(&self) -> &Fut::Output {
195 match &*self.future_or_output.get() {
196 FutureOrOutput::Output(ref item) => item,
197 FutureOrOutput::Future(_) => unreachable!(),
198 }
199 }
200}
201
202impl<Fut> Inner<Fut>
203where
204 Fut: Future,
205 Fut::Output: Clone,
206{
207 fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
209 let mut wakers_guard = self.notifier.wakers.lock().unwrap();
210
211 let wakers = match wakers_guard.as_mut() {
212 Some(wakers) => wakers,
213 None => return,
214 };
215
216 let new_waker = cx.waker();
217
218 if *waker_key == NULL_WAKER_KEY {
219 *waker_key = wakers.insert(Some(new_waker.clone()));
220 } else {
221 match wakers[*waker_key] {
222 Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
223 ref mut slot => *slot = Some(new_waker.clone()),
225 }
226 }
227 debug_assert!(*waker_key != NULL_WAKER_KEY);
228 }
229
230 unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
233 match Arc::try_unwrap(self) {
234 Ok(inner) => match inner.future_or_output.into_inner() {
235 FutureOrOutput::Output(item) => item,
236 FutureOrOutput::Future(_) => unreachable!(),
237 },
238 Err(inner) => inner.output().clone(),
239 }
240 }
241}
242
243impl<Fut> FusedFuture for Shared<Fut>
244where
245 Fut: Future,
246 Fut::Output: Clone,
247{
248 fn is_terminated(&self) -> bool {
249 self.inner.is_none()
250 }
251}
252
impl<Fut> Future for Shared<Fut>
where
    Fut: Future,
    Fut::Output: Clone,
{
    type Output = Fut::Output;

    /// Polls the shared future on behalf of this handle.
    ///
    /// At most one handle polls the wrapped future at a time. A handle gets that role by
    /// moving the state from IDLE to POLLING. Other handles register their waker and return
    /// `Pending`. When the wrapped future completes, the output is stored, the state becomes
    /// COMPLETE, and every registered waker is woken.
    ///
    /// # Panics
    ///
    /// - Panics if this handle is polled again after it returned `Ready`.
    /// - Panics if the wrapped future panicked during an earlier poll (POISONED state).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;

        // Taking `inner` marks this handle as finished; it is restored only on the `Pending`
        // return paths below.
        let inner = this.inner.take().expect("Shared future polled again after completion");

        // Fast path: the output has already been published.
        if inner.notifier.state.load(Acquire) == COMPLETE {
            // SAFETY: the state is COMPLETE, so the cell holds the output.
            return unsafe { Poll::Ready(inner.take_or_clone_output()) };
        }

        // Register before trying to become the poller. If another handle is currently polling,
        // this task's waker is then already in the slab when the notifier fires.
        inner.record_waker(&mut this.waker_key, cx);

        match inner
            .notifier
            .state
            .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
            .unwrap_or_else(|x| x)
        {
            IDLE => {
                // This handle won IDLE -> POLLING and polls the wrapped future below.
            }
            POLLING => {
                // Another handle is polling; our waker is registered, so just wait.
                this.inner = Some(inner);
                return Poll::Pending;
            }
            COMPLETE => {
                // SAFETY: the state is COMPLETE, so the cell holds the output.
                return unsafe { Poll::Ready(inner.take_or_clone_output()) };
            }
            POISONED => panic!("inner future panicked during poll"),
            _ => unreachable!(),
        }

        // The wrapped future is polled with a waker backed by the notifier, which wakes every
        // registered handle (see `ArcWake for Notifier`).
        let waker = waker_ref(&inner.notifier);
        let mut cx = Context::from_waker(&waker);

        // Drop guard: if the wrapped future panics while being polled, `did_not_panic` is
        // still false when the guard is dropped during unwinding, and the state becomes
        // POISONED.
        struct Reset<'a> {
            state: &'a AtomicUsize,
            did_not_panic: bool,
        }

        impl Drop for Reset<'_> {
            fn drop(&mut self) {
                if !self.did_not_panic {
                    self.state.store(POISONED, SeqCst);
                }
            }
        }

        let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };

        let output = {
            // SAFETY: this handle holds the POLLING state, so no other handle touches the cell
            // concurrently. The state was not COMPLETE, so the cell still holds the future.
            // The future lives inside the `Arc` allocation. It is only ever dropped in place
            // (when overwritten by the output) and never moved, so pinning it is sound.
            let future = unsafe {
                match &mut *inner.future_or_output.get() {
                    FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
                    _ => unreachable!(),
                }
            };

            let poll_result = future.poll(&mut cx);
            reset.did_not_panic = true;

            match poll_result {
                Poll::Pending => {
                    // Release the POLLING state. Only this handle can leave POLLING, so the
                    // exchange cannot fail.
                    if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
                    {
                        drop(reset);
                        this.inner = Some(inner);
                        return Poll::Pending;
                    } else {
                        unreachable!()
                    }
                }
                Poll::Ready(output) => output,
            }
        };

        // SAFETY: still POLLING (held by this handle), so the cell is accessed exclusively.
        // Assigning drops the completed future in place.
        unsafe {
            *inner.future_or_output.get() = FutureOrOutput::Output(output);
        }

        // Publish only after the output has been written.
        inner.notifier.state.store(COMPLETE, SeqCst);

        // Take the slab out and wake every registered task. Later `record_waker` calls see
        // `None` and return early, and `Drop` skips the slab.
        let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
        let mut wakers = wakers_guard.take().unwrap();
        for waker in wakers.drain().flatten() {
            waker.wake();
        }

        // Both `reset` and `wakers_guard` borrow from `inner`, so they must be dropped before
        // `inner` is moved into `take_or_clone_output` below.
        drop(reset);
        drop(wakers_guard);

        // SAFETY: the state is COMPLETE, so the cell holds the output.
        unsafe { Poll::Ready(inner.take_or_clone_output()) }
    }
}
361
362impl<Fut> Clone for Shared<Fut>
363where
364 Fut: Future,
365{
366 fn clone(&self) -> Self {
367 Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
368 }
369}
370
371impl<Fut> Drop for Shared<Fut>
372where
373 Fut: Future,
374{
375 fn drop(&mut self) {
376 if self.waker_key != NULL_WAKER_KEY {
377 if let Some(ref inner) = self.inner {
378 if let Ok(mut wakers) = inner.notifier.wakers.lock() {
379 if let Some(wakers) = wakers.as_mut() {
380 wakers.remove(self.waker_key);
381 }
382 }
383 }
384 }
385 }
386}
387
388impl ArcWake for Notifier {
389 fn wake_by_ref(arc_self: &Arc<Self>) {
390 let wakers = &mut *arc_self.wakers.lock().unwrap();
391 if let Some(wakers) = wakers.as_mut() {
392 for (_key, opt_waker) in wakers {
393 if let Some(waker) = opt_waker.take() {
394 waker.wake();
395 }
396 }
397 }
398 }
399}
400
401impl<Fut: Future> WeakShared<Fut> {
402 pub fn upgrade(&self) -> Option<Shared<Fut>> {
407 Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
408 }
409}