1use crate::loom::cell::UnsafeCell;
2use crate::loom::future::AtomicWaker;
3use crate::loom::sync::atomic::AtomicUsize;
4use crate::loom::sync::Arc;
5use crate::runtime::park::CachedParkThread;
6use crate::sync::mpsc::error::TryRecvError;
7use crate::sync::mpsc::{bounded, list, unbounded};
8use crate::sync::notify::Notify;
9use crate::util::cacheline::CachePadded;
10
11use std::fmt;
12use std::process;
13use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
14use std::task::Poll::{Pending, Ready};
15use std::task::{Context, Poll};
16
17pub(crate) struct Tx<T, S> {
19 inner: Arc<Chan<T, S>>,
20}
21
22impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
23 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
24 fmt.debug_struct("Tx").field("inner", &self.inner).finish()
25 }
26}
27
28pub(crate) struct Rx<T, S: Semaphore> {
30 inner: Arc<Chan<T, S>>,
31}
32
33impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> {
34 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
35 fmt.debug_struct("Rx").field("inner", &self.inner).finish()
36 }
37}
38
/// Permit bookkeeping, abstracted over the bounded and unbounded channel flavors.
pub(crate) trait Semaphore {
    /// Returns `true` once `close` has been called.
    fn is_closed(&self) -> bool;

    /// Returns `true` when no permits are outstanding (every permit that was
    /// taken has been handed back).
    fn is_idle(&self) -> bool;

    /// Hands back a single permit; the receiver calls this after taking one message.
    fn add_permit(&self);

    /// Hands back `n` permits at once; the receiver calls this after a batch receive.
    fn add_permits(&self, n: usize);

    /// Closes the semaphore; the receiver calls this when it closes the channel.
    fn close(&self);
}
50
51pub(super) struct Chan<T, S> {
52 tx: CachePadded<list::Tx<T>>,
54
55 rx_waker: CachePadded<AtomicWaker>,
57
58 notify_rx_closed: Notify,
60
61 semaphore: S,
63
64 tx_count: AtomicUsize,
68
69 tx_weak_count: AtomicUsize,
71
72 rx_fields: UnsafeCell<RxFields<T>>,
74}
75
76impl<T, S> fmt::Debug for Chan<T, S>
77where
78 S: fmt::Debug,
79{
80 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
81 fmt.debug_struct("Chan")
82 .field("tx", &*self.tx)
83 .field("semaphore", &self.semaphore)
84 .field("rx_waker", &*self.rx_waker)
85 .field("tx_count", &self.tx_count)
86 .field("rx_fields", &"...")
87 .finish()
88 }
89}
90
91struct RxFields<T> {
93 list: list::Rx<T>,
95
96 rx_closed: bool,
98}
99
100impl<T> fmt::Debug for RxFields<T> {
101 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
102 fmt.debug_struct("RxFields")
103 .field("list", &self.list)
104 .field("rx_closed", &self.rx_closed)
105 .finish()
106 }
107}
108
109unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
110unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
111
112pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
113 let (tx, rx) = list::channel();
114
115 let chan = Arc::new(Chan {
116 notify_rx_closed: Notify::new(),
117 tx: CachePadded::new(tx),
118 semaphore,
119 rx_waker: CachePadded::new(AtomicWaker::new()),
120 tx_count: AtomicUsize::new(1),
121 tx_weak_count: AtomicUsize::new(0),
122 rx_fields: UnsafeCell::new(RxFields {
123 list: rx,
124 rx_closed: false,
125 }),
126 });
127
128 (Tx::new(chan.clone()), Rx::new(chan))
129}
130
131impl<T, S> Tx<T, S> {
134 fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
135 Tx { inner: chan }
136 }
137
138 pub(super) fn strong_count(&self) -> usize {
139 self.inner.tx_count.load(Acquire)
140 }
141
142 pub(super) fn weak_count(&self) -> usize {
143 self.inner.tx_weak_count.load(Relaxed)
144 }
145
146 pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> {
147 self.inner.increment_weak_count();
148
149 self.inner.clone()
150 }
151
152 pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> {
154 let mut tx_count = chan.tx_count.load(Acquire);
155
156 loop {
157 if tx_count == 0 {
158 return None;
160 }
161
162 match chan
163 .tx_count
164 .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire)
165 {
166 Ok(_) => return Some(Tx { inner: chan }),
167 Err(prev_count) => tx_count = prev_count,
168 }
169 }
170 }
171
172 pub(super) fn semaphore(&self) -> &S {
173 &self.inner.semaphore
174 }
175
176 pub(crate) fn send(&self, value: T) {
178 self.inner.send(value);
179 }
180
181 pub(crate) fn wake_rx(&self) {
183 self.inner.rx_waker.wake();
184 }
185
186 pub(crate) fn same_channel(&self, other: &Self) -> bool {
188 Arc::ptr_eq(&self.inner, &other.inner)
189 }
190}
191
192impl<T, S: Semaphore> Tx<T, S> {
193 pub(crate) fn is_closed(&self) -> bool {
194 self.inner.semaphore.is_closed()
195 }
196
197 pub(crate) async fn closed(&self) {
198 let notified = self.inner.notify_rx_closed.notified();
202
203 if self.inner.semaphore.is_closed() {
204 return;
205 }
206 notified.await;
207 }
208}
209
210impl<T, S> Clone for Tx<T, S> {
211 fn clone(&self) -> Tx<T, S> {
212 self.inner.tx_count.fetch_add(1, Relaxed);
215
216 Tx {
217 inner: self.inner.clone(),
218 }
219 }
220}
221
222impl<T, S> Drop for Tx<T, S> {
223 fn drop(&mut self) {
224 if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
225 return;
226 }
227
228 self.inner.tx.close();
230
231 self.wake_rx();
233 }
234}
235
236impl<T, S: Semaphore> Rx<T, S> {
239 fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
240 Rx { inner: chan }
241 }
242
243 pub(crate) fn close(&mut self) {
244 self.inner.rx_fields.with_mut(|rx_fields_ptr| {
245 let rx_fields = unsafe { &mut *rx_fields_ptr };
246
247 if rx_fields.rx_closed {
248 return;
249 }
250
251 rx_fields.rx_closed = true;
252 });
253
254 self.inner.semaphore.close();
255 self.inner.notify_rx_closed.notify_waiters();
256 }
257
258 pub(crate) fn is_closed(&self) -> bool {
259 self.inner.semaphore.is_closed() || self.inner.tx_count.load(Acquire) == 0
269 }
270
271 pub(crate) fn is_empty(&self) -> bool {
272 self.inner.rx_fields.with(|rx_fields_ptr| {
273 let rx_fields = unsafe { &*rx_fields_ptr };
274 rx_fields.list.is_empty(&self.inner.tx)
275 })
276 }
277
278 pub(crate) fn len(&self) -> usize {
279 self.inner.rx_fields.with(|rx_fields_ptr| {
280 let rx_fields = unsafe { &*rx_fields_ptr };
281 rx_fields.list.len(&self.inner.tx)
282 })
283 }
284
285 pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
287 use super::block::Read;
288
289 ready!(crate::trace::trace_leaf(cx));
290
291 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
293
294 self.inner.rx_fields.with_mut(|rx_fields_ptr| {
295 let rx_fields = unsafe { &mut *rx_fields_ptr };
296
297 macro_rules! try_recv {
298 () => {
299 match rx_fields.list.pop(&self.inner.tx) {
300 Some(Read::Value(value)) => {
301 self.inner.semaphore.add_permit();
302 coop.made_progress();
303 return Ready(Some(value));
304 }
305 Some(Read::Closed) => {
306 assert!(self.inner.semaphore.is_idle());
313 coop.made_progress();
314 return Ready(None);
315 }
316 None => {} }
318 };
319 }
320
321 try_recv!();
322
323 self.inner.rx_waker.register_by_ref(cx.waker());
324
325 try_recv!();
329
330 if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
331 coop.made_progress();
332 Ready(None)
333 } else {
334 Pending
335 }
336 })
337 }
338
339 pub(crate) fn recv_many(
344 &mut self,
345 cx: &mut Context<'_>,
346 buffer: &mut Vec<T>,
347 limit: usize,
348 ) -> Poll<usize> {
349 use super::block::Read;
350
351 ready!(crate::trace::trace_leaf(cx));
352
353 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
355
356 if limit == 0 {
357 coop.made_progress();
358 return Ready(0usize);
359 }
360
361 let mut remaining = limit;
362 let initial_length = buffer.len();
363
364 self.inner.rx_fields.with_mut(|rx_fields_ptr| {
365 let rx_fields = unsafe { &mut *rx_fields_ptr };
366 macro_rules! try_recv {
367 () => {
368 while remaining > 0 {
369 match rx_fields.list.pop(&self.inner.tx) {
370 Some(Read::Value(value)) => {
371 remaining -= 1;
372 buffer.push(value);
373 }
374
375 Some(Read::Closed) => {
376 let number_added = buffer.len() - initial_length;
377 if number_added > 0 {
378 self.inner.semaphore.add_permits(number_added);
379 }
380 assert!(self.inner.semaphore.is_idle());
387 coop.made_progress();
388 return Ready(number_added);
389 }
390
391 None => {
392 break; }
394 }
395 }
396 let number_added = buffer.len() - initial_length;
397 if number_added > 0 {
398 self.inner.semaphore.add_permits(number_added);
399 coop.made_progress();
400 return Ready(number_added);
401 }
402 };
403 }
404
405 try_recv!();
406
407 self.inner.rx_waker.register_by_ref(cx.waker());
408
409 try_recv!();
413
414 if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
415 assert!(buffer.is_empty());
416 coop.made_progress();
417 Ready(0usize)
418 } else {
419 Pending
420 }
421 })
422 }
423
424 pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
426 use super::list::TryPopResult;
427
428 self.inner.rx_fields.with_mut(|rx_fields_ptr| {
429 let rx_fields = unsafe { &mut *rx_fields_ptr };
430
431 macro_rules! try_recv {
432 () => {
433 match rx_fields.list.try_pop(&self.inner.tx) {
434 TryPopResult::Ok(value) => {
435 self.inner.semaphore.add_permit();
436 return Ok(value);
437 }
438 TryPopResult::Closed => return Err(TryRecvError::Disconnected),
439 TryPopResult::Empty => return Err(TryRecvError::Empty),
440 TryPopResult::Busy => {} }
442 };
443 }
444
445 try_recv!();
446
447 self.inner.rx_waker.wake();
455
456 let mut park = CachedParkThread::new();
458 let waker = park.waker().unwrap();
459 loop {
460 self.inner.rx_waker.register_by_ref(&waker);
461 try_recv!();
464 park.park();
465 }
466 })
467 }
468
469 pub(super) fn semaphore(&self) -> &S {
470 &self.inner.semaphore
471 }
472}
473
474impl<T, S: Semaphore> Drop for Rx<T, S> {
475 fn drop(&mut self) {
476 use super::block::Read::Value;
477
478 self.close();
479
480 self.inner.rx_fields.with_mut(|rx_fields_ptr| {
481 let rx_fields = unsafe { &mut *rx_fields_ptr };
482
483 while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) {
484 self.inner.semaphore.add_permit();
485 }
486 });
487 }
488}
489
490impl<T, S> Chan<T, S> {
493 fn send(&self, value: T) {
494 self.tx.push(value);
496
497 self.rx_waker.wake();
499 }
500
501 pub(super) fn decrement_weak_count(&self) {
502 self.tx_weak_count.fetch_sub(1, Relaxed);
503 }
504
505 pub(super) fn increment_weak_count(&self) {
506 self.tx_weak_count.fetch_add(1, Relaxed);
507 }
508
509 pub(super) fn strong_count(&self) -> usize {
510 self.tx_count.load(Acquire)
511 }
512
513 pub(super) fn weak_count(&self) -> usize {
514 self.tx_weak_count.load(Relaxed)
515 }
516}
517
518impl<T, S> Drop for Chan<T, S> {
519 fn drop(&mut self) {
520 use super::block::Read::Value;
521
522 self.rx_fields.with_mut(|rx_fields_ptr| {
525 let rx_fields = unsafe { &mut *rx_fields_ptr };
526
527 while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {}
528 unsafe { rx_fields.list.free_blocks() };
529 });
530 }
531}
532
533impl Semaphore for bounded::Semaphore {
536 fn add_permit(&self) {
537 self.semaphore.release(1);
538 }
539
540 fn add_permits(&self, n: usize) {
541 self.semaphore.release(n)
542 }
543
544 fn is_idle(&self) -> bool {
545 self.semaphore.available_permits() == self.bound
546 }
547
548 fn close(&self) {
549 self.semaphore.close();
550 }
551
552 fn is_closed(&self) -> bool {
553 self.semaphore.is_closed()
554 }
555}
556
557impl Semaphore for unbounded::Semaphore {
560 fn add_permit(&self) {
561 let prev = self.0.fetch_sub(2, Release);
562
563 if prev >> 1 == 0 {
564 process::abort();
566 }
567 }
568
569 fn add_permits(&self, n: usize) {
570 let prev = self.0.fetch_sub(n << 1, Release);
571
572 if (prev >> 1) < n {
573 process::abort();
575 }
576 }
577
578 fn is_idle(&self) -> bool {
579 self.0.load(Acquire) >> 1 == 0
580 }
581
582 fn close(&self) {
583 self.0.fetch_or(1, Release);
584 }
585
586 fn is_closed(&self) -> bool {
587 self.0.load(Acquire) & 1 == 1
588 }
589}