// tokio/sync/mpsc/chan.rs

1use crate::loom::cell::UnsafeCell;
2use crate::loom::future::AtomicWaker;
3use crate::loom::sync::atomic::AtomicUsize;
4use crate::loom::sync::Arc;
5use crate::runtime::park::CachedParkThread;
6use crate::sync::mpsc::error::TryRecvError;
7use crate::sync::mpsc::{bounded, list, unbounded};
8use crate::sync::notify::Notify;
9use crate::util::cacheline::CachePadded;
10
11use std::fmt;
12use std::process;
13use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
14use std::task::Poll::{Pending, Ready};
15use std::task::{Context, Poll};
16
/// Channel sender.
///
/// A handle to the send half of the channel. Cloning a `Tx` bumps the
/// strong-sender count stored in `Chan::tx_count` (see `Clone`/`Drop` impls
/// below); the `Arc` itself only manages the allocation's lifetime.
pub(crate) struct Tx<T, S> {
    // Shared channel state, also referenced by the receiver and by weak
    // senders (which hold the `Arc` without a `tx_count` increment).
    inner: Arc<Chan<T, S>>,
}
21
22impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
23    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
24        fmt.debug_struct("Tx").field("inner", &self.inner).finish()
25    }
26}
27
/// Channel receiver.
///
/// There is a single receiver per channel; it alone may touch
/// `Chan::rx_fields` (see the `Sync` safety comment on `Chan`).
pub(crate) struct Rx<T, S: Semaphore> {
    // Shared channel state.
    inner: Arc<Chan<T, S>>,
}
32
33impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> {
34    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
35        fmt.debug_struct("Rx").field("inner", &self.inner).finish()
36    }
37}
38
/// Abstraction over the channel's capacity policy — implemented below for
/// both the bounded and the unbounded semaphore.
pub(crate) trait Semaphore {
    /// Returns `true` when all permits are back, i.e. the channel currently
    /// buffers no messages.
    fn is_idle(&self) -> bool;

    /// Returns one permit to the channel (called once per received value).
    fn add_permit(&self);

    /// Returns `n` permits to the channel at once (batch receive).
    fn add_permits(&self, n: usize);

    /// Closes the semaphore so that sends fail from now on.
    fn close(&self);

    /// Returns `true` if `close` has been called.
    fn is_closed(&self) -> bool;
}
50
/// State shared by every sender and the receiver of one channel.
pub(super) struct Chan<T, S> {
    /// Handle to the push half of the lock-free list.
    tx: CachePadded<list::Tx<T>>,

    /// Receiver waker. Notified when a value is pushed into the channel.
    rx_waker: CachePadded<AtomicWaker>,

    /// Notifies all tasks listening for the receiver being dropped.
    notify_rx_closed: Notify,

    /// Coordinates access to channel's capacity.
    semaphore: S,

    /// Tracks the number of outstanding sender handles.
    ///
    /// When this drops to zero, the send half of the channel is closed.
    tx_count: AtomicUsize,

    /// Tracks the number of outstanding weak sender handles.
    tx_weak_count: AtomicUsize,

    /// Only accessed by `Rx` handle.
    rx_fields: UnsafeCell<RxFields<T>>,
}
75
76impl<T, S> fmt::Debug for Chan<T, S>
77where
78    S: fmt::Debug,
79{
80    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
81        fmt.debug_struct("Chan")
82            .field("tx", &*self.tx)
83            .field("semaphore", &self.semaphore)
84            .field("rx_waker", &*self.rx_waker)
85            .field("tx_count", &self.tx_count)
86            .field("rx_fields", &"...")
87            .finish()
88    }
89}
90
/// Fields only accessed by `Rx` handle.
struct RxFields<T> {
    /// Channel receiver. This field is only accessed by the `Receiver` type.
    list: list::Rx<T>,

    /// `true` if `Rx::close` is called.
    rx_closed: bool,
}
99
100impl<T> fmt::Debug for RxFields<T> {
101    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
102        fmt.debug_struct("RxFields")
103            .field("list", &self.list)
104            .field("rx_closed", &self.rx_closed)
105            .finish()
106    }
107}
108
// SAFETY: `Chan` may move between threads when its contents can: values of
// `T` travel through the list and `S` is part of the state.
unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
// SAFETY: shared (`&Chan`) access only touches the atomic / `Sync` fields;
// `rx_fields` is only ever accessed through the single `Rx` handle (see the
// field comment above), so `T: Send` suffices — `T: Sync` is not required.
unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
111
112pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
113    let (tx, rx) = list::channel();
114
115    let chan = Arc::new(Chan {
116        notify_rx_closed: Notify::new(),
117        tx: CachePadded::new(tx),
118        semaphore,
119        rx_waker: CachePadded::new(AtomicWaker::new()),
120        tx_count: AtomicUsize::new(1),
121        tx_weak_count: AtomicUsize::new(0),
122        rx_fields: UnsafeCell::new(RxFields {
123            list: rx,
124            rx_closed: false,
125        }),
126    });
127
128    (Tx::new(chan.clone()), Rx::new(chan))
129}
130
131// ===== impl Tx =====
132
impl<T, S> Tx<T, S> {
    /// Wraps the shared channel state in a sender handle.
    fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
        Tx { inner: chan }
    }

    /// Returns the number of outstanding strong sender handles.
    pub(super) fn strong_count(&self) -> usize {
        self.inner.tx_count.load(Acquire)
    }

    /// Returns the number of outstanding weak sender handles.
    pub(super) fn weak_count(&self) -> usize {
        self.inner.tx_weak_count.load(Relaxed)
    }

    /// Produces the state for a weak sender: bumps the weak count and hands
    /// out a clone of the shared `Arc` (without touching `tx_count`).
    pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> {
        self.inner.increment_weak_count();

        self.inner.clone()
    }

    // Returns the upgraded channel or None if the upgrade failed.
    pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> {
        let mut tx_count = chan.tx_count.load(Acquire);

        loop {
            if tx_count == 0 {
                // channel is closed
                return None;
            }

            // CAS loop: only increment the strong count if it is still
            // non-zero, so upgrading can never resurrect a channel whose
            // last strong sender was already dropped.
            match chan
                .tx_count
                .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire)
            {
                Ok(_) => return Some(Tx { inner: chan }),
                // `compare_exchange_weak` may fail spuriously; retry with
                // the freshly observed count.
                Err(prev_count) => tx_count = prev_count,
            }
        }
    }

    /// Returns a reference to the channel's capacity semaphore.
    pub(super) fn semaphore(&self) -> &S {
        &self.inner.semaphore
    }

    /// Send a message and notify the receiver.
    pub(crate) fn send(&self, value: T) {
        self.inner.send(value);
    }

    /// Wake the receive half
    pub(crate) fn wake_rx(&self) {
        self.inner.rx_waker.wake();
    }

    /// Returns `true` if senders belong to the same channel.
    pub(crate) fn same_channel(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}
191
impl<T, S: Semaphore> Tx<T, S> {
    /// Returns `true` once the receive half has been closed (the semaphore
    /// is closed either by `Rx::close` or by dropping the receiver).
    pub(crate) fn is_closed(&self) -> bool {
        self.inner.semaphore.is_closed()
    }

    /// Completes when the receive half is closed or dropped.
    pub(crate) async fn closed(&self) {
        // In order to avoid a race condition, we first request a notification,
        // **then** check whether the semaphore is closed. If the semaphore is
        // closed the notification request is dropped.
        let notified = self.inner.notify_rx_closed.notified();

        if self.inner.semaphore.is_closed() {
            return;
        }
        notified.await;
    }
}
209
210impl<T, S> Clone for Tx<T, S> {
211    fn clone(&self) -> Tx<T, S> {
212        // Using a Relaxed ordering here is sufficient as the caller holds a
213        // strong ref to `self`, preventing a concurrent decrement to zero.
214        self.inner.tx_count.fetch_add(1, Relaxed);
215
216        Tx {
217            inner: self.inner.clone(),
218        }
219    }
220}
221
impl<T, S> Drop for Tx<T, S> {
    fn drop(&mut self) {
        // AcqRel: the Release half publishes this sender's writes to whichever
        // thread performs the final decrement; the Acquire half lets that
        // final thread observe all prior senders' writes before closing.
        if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
            return;
        }

        // This was the last strong sender.
        // Close the list, which sends a `Close` message
        self.inner.tx.close();

        // Notify the receiver
        self.wake_rx();
    }
}
235
236// ===== impl Rx =====
237
impl<T, S: Semaphore> Rx<T, S> {
    /// Wraps the shared channel state in the (unique) receiver handle.
    fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
        Rx { inner: chan }
    }

    /// Closes the receive half: marks the receiver closed, closes the
    /// semaphore so sends fail, and wakes tasks waiting in `Tx::closed`.
    /// Values already in the channel can still be received.
    pub(crate) fn close(&mut self) {
        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: `rx_fields` is only accessed by the single receiver,
            // and we hold `&mut self`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            if rx_fields.rx_closed {
                // Already closed; don't notify waiters a second time.
                return;
            }

            rx_fields.rx_closed = true;
        });

        self.inner.semaphore.close();
        self.inner.notify_rx_closed.notify_waiters();
    }

    /// Returns `true` if the channel is closed from the receiver's view.
    pub(crate) fn is_closed(&self) -> bool {
        // There are two internal states that can represent a closed channel
        //
        //  1. When `close` is called.
        //  In this case, the inner semaphore will be closed.
        //
        //  2. When all senders are dropped.
        //  In this case, the semaphore remains unclosed, and the `index` in the list won't
        //  reach the tail position. It is necessary to check the list if the last block is
        //  `closed`.
        self.inner.semaphore.is_closed() || self.inner.tx_count.load(Acquire) == 0
    }

    /// Returns `true` if no values are currently buffered.
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.rx_fields.with(|rx_fields_ptr| {
            // Safety: read-only access by the single receiver.
            let rx_fields = unsafe { &*rx_fields_ptr };
            rx_fields.list.is_empty(&self.inner.tx)
        })
    }

    /// Returns the number of currently buffered values.
    pub(crate) fn len(&self) -> usize {
        self.inner.rx_fields.with(|rx_fields_ptr| {
            // Safety: read-only access by the single receiver.
            let rx_fields = unsafe { &*rx_fields_ptr };
            rx_fields.list.len(&self.inner.tx)
        })
    }

    /// Receive the next value
    ///
    /// Returns `Ready(Some(value))` on success, `Ready(None)` when the
    /// channel is closed and drained, and `Pending` (with the waker
    /// registered) when the channel is empty but still open.
    pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        use super::block::Read;

        ready!(crate::trace::trace_leaf(cx));

        // Keep track of task budget
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: exclusive access by the single receiver via `&mut self`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Macro so that the early `return`s exit the enclosing closure.
            macro_rules! try_recv {
                () => {
                    match rx_fields.list.pop(&self.inner.tx) {
                        Some(Read::Value(value)) => {
                            // Return a permit for the slot just freed.
                            self.inner.semaphore.add_permit();
                            coop.made_progress();
                            return Ready(Some(value));
                        }
                        Some(Read::Closed) => {
                            // TODO: This check may not be required as it most
                            // likely can only return `true` at this point. A
                            // channel is closed when all tx handles are
                            // dropped. Dropping a tx handle releases memory,
                            // which ensures that if dropping the tx handle is
                            // visible, then all messages sent are also visible.
                            assert!(self.inner.semaphore.is_idle());
                            coop.made_progress();
                            return Ready(None);
                        }
                        None => {} // fall through
                    }
                };
            }

            try_recv!();

            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to read
            // and registering the task, so we have to check the channel a
            // second time here.
            try_recv!();

            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                // Receiver closed itself and the buffer is drained.
                coop.made_progress();
                Ready(None)
            } else {
                Pending
            }
        })
    }

    /// Receives up to `limit` values into `buffer`
    ///
    /// For `limit > 0`, receives up to limit values into `buffer`.
    /// For `limit == 0`, immediately returns Ready(0).
    pub(crate) fn recv_many(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &mut Vec<T>,
        limit: usize,
    ) -> Poll<usize> {
        use super::block::Read;

        ready!(crate::trace::trace_leaf(cx));

        // Keep track of task budget
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        if limit == 0 {
            coop.made_progress();
            return Ready(0usize);
        }

        let mut remaining = limit;
        // `buffer` may already hold values; only count what we append.
        let initial_length = buffer.len();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: exclusive access by the single receiver via `&mut self`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };
            // Macro so that the early `return`s exit the enclosing closure.
            // Permits are returned in one batch (`add_permits`) rather than
            // one at a time.
            macro_rules! try_recv {
                () => {
                    while remaining > 0 {
                        match rx_fields.list.pop(&self.inner.tx) {
                            Some(Read::Value(value)) => {
                                remaining -= 1;
                                buffer.push(value);
                            }

                            Some(Read::Closed) => {
                                let number_added = buffer.len() - initial_length;
                                if number_added > 0 {
                                    self.inner.semaphore.add_permits(number_added);
                                }
                                // TODO: This check may not be required as it most
                                // likely can only return `true` at this point. A
                                // channel is closed when all tx handles are
                                // dropped. Dropping a tx handle releases memory,
                                // which ensures that if dropping the tx handle is
                                // visible, then all messages sent are also visible.
                                assert!(self.inner.semaphore.is_idle());
                                coop.made_progress();
                                return Ready(number_added);
                            }

                            None => {
                                break; // fall through
                            }
                        }
                    }
                    let number_added = buffer.len() - initial_length;
                    if number_added > 0 {
                        self.inner.semaphore.add_permits(number_added);
                        coop.made_progress();
                        return Ready(number_added);
                    }
                };
            }

            try_recv!();

            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to read
            // and registering the task, so we have to check the channel a
            // second time here.
            try_recv!();

            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                // Both try_recv! calls returned nothing, so nothing was added.
                assert!(buffer.is_empty());
                coop.made_progress();
                Ready(0usize)
            } else {
                Pending
            }
        })
    }

    /// Try to receive the next value.
    ///
    /// Non-async path: if `try_pop` reports `Busy` (a send is mid-flight),
    /// this parks the current thread until the send completes rather than
    /// returning a spurious `Empty`.
    pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
        use super::list::TryPopResult;

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: exclusive access by the single receiver via `&mut self`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Macro so that the early `return`s exit the enclosing closure.
            macro_rules! try_recv {
                () => {
                    match rx_fields.list.try_pop(&self.inner.tx) {
                        TryPopResult::Ok(value) => {
                            self.inner.semaphore.add_permit();
                            return Ok(value);
                        }
                        TryPopResult::Closed => return Err(TryRecvError::Disconnected),
                        TryPopResult::Empty => return Err(TryRecvError::Empty),
                        TryPopResult::Busy => {} // fall through
                    }
                };
            }

            try_recv!();

            // If a previous `poll_recv` call has set a waker, we wake it here.
            // This allows us to put our own CachedParkThread waker in the
            // AtomicWaker slot instead.
            //
            // This is not a spurious wakeup to `poll_recv` since we just got a
            // Busy from `try_pop`, which only happens if there are messages in
            // the queue.
            self.inner.rx_waker.wake();

            // Park the thread until the problematic send has completed.
            let mut park = CachedParkThread::new();
            let waker = park.waker().unwrap();
            loop {
                self.inner.rx_waker.register_by_ref(&waker);
                // It is possible that the problematic send has now completed,
                // so we have to check for messages again.
                try_recv!();
                park.park();
            }
        })
    }

    /// Returns a reference to the channel's capacity semaphore.
    pub(super) fn semaphore(&self) -> &S {
        &self.inner.semaphore
    }
}
473
impl<T, S: Semaphore> Drop for Rx<T, S> {
    fn drop(&mut self) {
        use super::block::Read::Value;

        // Close first so senders start failing / `Tx::closed` waiters wake.
        self.close();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: `rx_fields` is only accessed by the single receiver,
            // which is being dropped right here.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Drain and drop any buffered values, returning a permit for
            // each so blocked senders can observe the closed semaphore.
            while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) {
                self.inner.semaphore.add_permit();
            }
        });
    }
}
489
490// ===== impl Chan =====
491
impl<T, S> Chan<T, S> {
    /// Pushes a value onto the list, then wakes the receiver (in that order,
    /// so the woken receiver can find the value).
    fn send(&self, value: T) {
        // Push the value
        self.tx.push(value);

        // Notify the rx task
        self.rx_waker.wake();
    }

    /// Decrements the weak-sender count. Relaxed suffices: this counter is
    /// only ever read back via `weak_count` below.
    pub(super) fn decrement_weak_count(&self) {
        self.tx_weak_count.fetch_sub(1, Relaxed);
    }

    /// Increments the weak-sender count (see `Tx::downgrade`).
    pub(super) fn increment_weak_count(&self) {
        self.tx_weak_count.fetch_add(1, Relaxed);
    }

    /// Returns the number of outstanding strong sender handles.
    pub(super) fn strong_count(&self) -> usize {
        self.tx_count.load(Acquire)
    }

    /// Returns the number of outstanding weak sender handles.
    pub(super) fn weak_count(&self) -> usize {
        self.tx_weak_count.load(Relaxed)
    }
}
517
impl<T, S> Drop for Chan<T, S> {
    fn drop(&mut self) {
        use super::block::Read::Value;

        // Safety: the only owner of the rx fields is Chan, and being
        // inside its own Drop means we're the last ones to touch it.
        self.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Pop (and drop) every remaining value, then free the list's
            // backing blocks.
            while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {}
            // Safety: all values have been popped above and no other handle
            // to this channel exists anymore.
            unsafe { rx_fields.list.free_blocks() };
        });
    }
}
532
533// ===== impl Semaphore for (::Semaphore, capacity) =====
534
// Bounded channels delegate capacity tracking to the async semaphore.
impl Semaphore for bounded::Semaphore {
    fn add_permit(&self) {
        self.semaphore.release(1);
    }

    fn add_permits(&self, n: usize) {
        self.semaphore.release(n)
    }

    // Idle when every permit has been returned, i.e. `available == bound`
    // means no message is currently buffered.
    fn is_idle(&self) -> bool {
        self.semaphore.available_permits() == self.bound
    }

    fn close(&self) {
        self.semaphore.close();
    }

    fn is_closed(&self) -> bool {
        self.semaphore.is_closed()
    }
}
556
557// ===== impl Semaphore for AtomicUsize =====
558
// The unbounded "semaphore" packs its state into a single usize:
// bit 0 is the closed flag, and the remaining bits (value >> 1) track the
// number of in-flight messages — hence the `2`, `<< 1` and `>> 1` below.
impl Semaphore for unbounded::Semaphore {
    fn add_permit(&self) {
        // Subtract one message: encoded as 2 to skip over the flag bit.
        let prev = self.0.fetch_sub(2, Release);

        if prev >> 1 == 0 {
            // Something went wrong: the message count underflowed, so the
            // channel state is corrupt. Abort rather than continue.
            process::abort();
        }
    }

    fn add_permits(&self, n: usize) {
        let prev = self.0.fetch_sub(n << 1, Release);

        if (prev >> 1) < n {
            // Something went wrong
            process::abort();
        }
    }

    // Idle when no messages are in flight (the closed bit is masked out by
    // the shift).
    fn is_idle(&self) -> bool {
        self.0.load(Acquire) >> 1 == 0
    }

    fn close(&self) {
        // Set the closed flag without disturbing the message count.
        self.0.fetch_or(1, Release);
    }

    fn is_closed(&self) -> bool {
        self.0.load(Acquire) & 1 == 1
    }
}