1#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
2use crate::loom::cell::UnsafeCell;
19use crate::loom::sync::atomic::AtomicUsize;
20use crate::loom::sync::{Mutex, MutexGuard};
21use crate::util::linked_list::{self, LinkedList};
22#[cfg(all(tokio_unstable, feature = "tracing"))]
23use crate::util::trace;
24use crate::util::WakeList;
25
26use std::future::Future;
27use std::marker::PhantomPinned;
28use std::pin::Pin;
29use std::ptr::NonNull;
30use std::sync::atomic::Ordering::*;
31use std::task::{Context, Poll, Waker};
32use std::{cmp, fmt};
33
/// A counting semaphore whose permits can be acquired several at a time.
pub(crate) struct Semaphore {
    /// Queue of pending waiters plus a `closed` flag, guarded by a mutex.
    waiters: Mutex<Waitlist>,
    /// Permit counter. The count is stored shifted left by `PERMIT_SHIFT`;
    /// the low bit is the `CLOSED` flag.
    permits: AtomicUsize,
    /// Tracing span for this resource (only with unstable tracing enabled).
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
42
/// State protected by the semaphore's mutex.
struct Waitlist {
    /// Intrusive list of waiters. In this file new waiters are pushed at the
    /// front and permits are handed out starting from the back.
    queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
    /// `true` once the semaphore has been closed (or was created closed).
    closed: bool,
}
47
/// Error returned by `Semaphore::try_acquire`.
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
    /// The semaphore has been closed, so no permits can be issued.
    Closed,

    /// Fewer permits are currently available than were requested.
    NoPermits,
}
/// Error produced by the `Acquire` future when the semaphore is closed.
///
/// The private unit field keeps the type from being constructed outside this
/// module.
#[derive(Debug)]
pub struct AcquireError(());
70
/// Future returned by `Semaphore::acquire`.
pub(crate) struct Acquire<'a> {
    /// Wait-list entry for this future; linked into the semaphore's queue
    /// while `queued` is `true`.
    node: Waiter,
    /// The semaphore permits are being acquired from.
    semaphore: &'a Semaphore,
    /// Total number of permits originally requested.
    num_permits: usize,
    /// Whether `node` has been pushed onto the wait list by a pending poll.
    queued: bool,
}
77
/// An entry in the semaphore's wait list.
struct Waiter {
    /// Number of permits this waiter still needs. Starts at the requested
    /// amount and is decremented by `assign_permits`.
    state: AtomicUsize,

    /// Waker to notify once the waiter has been fully satisfied or the
    /// semaphore is closed. Within this file it is only accessed while the
    /// wait-list mutex is held.
    waker: UnsafeCell<Option<Waker>>,

    /// Intrusive linked-list pointers.
    pointers: linked_list::Pointers<Waiter>,

    /// Tracing context for the async operation (unstable tracing only).
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    ctx: trace::AsyncOpTracingCtx,

    /// Makes the type `!Unpin`, since the list stores raw pointers to it.
    _p: PhantomPinned,
}
113
// Generates `Waiter::addr_of_pointers`, which the `linked_list::Link` impl in
// this file calls to get from a raw `NonNull<Waiter>` to its `pointers` field.
// NOTE(review): presumably the macro does this without materialising a
// `&Waiter`; confirm against the macro definition.
generate_addr_of_methods! {
    impl<> Waiter {
        unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
            &self.pointers
        }
    }
}
121
122impl Semaphore {
/// Largest number of permits a semaphore may hold.
///
/// The cap is `usize::MAX >> 3`, so a shifted count always fits in the
/// counter with room to spare above the single `CLOSED` flag bit.
pub(crate) const MAX_PERMITS: usize = usize::MAX >> 3;
/// Flag bit in `permits` marking the semaphore as closed.
const CLOSED: usize = 1;
/// Permit counts are stored shifted left by this amount so the low bit stays
/// free for `CLOSED`.
const PERMIT_SHIFT: usize = 1;
136
137 pub(crate) fn new(permits: usize) -> Self {
141 assert!(
142 permits <= Self::MAX_PERMITS,
143 "a semaphore may not have more than MAX_PERMITS permits ({})",
144 Self::MAX_PERMITS
145 );
146
147 #[cfg(all(tokio_unstable, feature = "tracing"))]
148 let resource_span = {
149 let resource_span = tracing::trace_span!(
150 "runtime.resource",
151 concrete_type = "Semaphore",
152 kind = "Sync",
153 is_internal = true
154 );
155
156 resource_span.in_scope(|| {
157 tracing::trace!(
158 target: "runtime::resource::state_update",
159 permits = permits,
160 permits.op = "override",
161 )
162 });
163 resource_span
164 };
165
166 Self {
167 permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
168 waiters: Mutex::new(Waitlist {
169 queue: LinkedList::new(),
170 closed: false,
171 }),
172 #[cfg(all(tokio_unstable, feature = "tracing"))]
173 resource_span,
174 }
175 }
176
177 #[cfg(not(all(loom, test)))]
181 pub(crate) const fn const_new(permits: usize) -> Self {
182 assert!(permits <= Self::MAX_PERMITS);
183
184 Self {
185 permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
186 waiters: Mutex::const_new(Waitlist {
187 queue: LinkedList::new(),
188 closed: false,
189 }),
190 #[cfg(all(tokio_unstable, feature = "tracing"))]
191 resource_span: tracing::Span::none(),
192 }
193 }
194
195 pub(crate) fn new_closed() -> Self {
197 Self {
198 permits: AtomicUsize::new(Self::CLOSED),
199 waiters: Mutex::new(Waitlist {
200 queue: LinkedList::new(),
201 closed: true,
202 }),
203 #[cfg(all(tokio_unstable, feature = "tracing"))]
204 resource_span: tracing::Span::none(),
205 }
206 }
207
208 #[cfg(not(all(loom, test)))]
210 pub(crate) const fn const_new_closed() -> Self {
211 Self {
212 permits: AtomicUsize::new(Self::CLOSED),
213 waiters: Mutex::const_new(Waitlist {
214 queue: LinkedList::new(),
215 closed: true,
216 }),
217 #[cfg(all(tokio_unstable, feature = "tracing"))]
218 resource_span: tracing::Span::none(),
219 }
220 }
221
222 pub(crate) fn available_permits(&self) -> usize {
224 self.permits.load(Acquire) >> Self::PERMIT_SHIFT
225 }
226
227 pub(crate) fn release(&self, added: usize) {
231 if added == 0 {
232 return;
233 }
234
235 self.add_permits_locked(added, self.waiters.lock());
237 }
238
239 pub(crate) fn close(&self) {
242 let mut waiters = self.waiters.lock();
243 self.permits.fetch_or(Self::CLOSED, Release);
251 waiters.closed = true;
252 while let Some(mut waiter) = waiters.queue.pop_back() {
253 let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) };
254 if let Some(waker) = waker {
255 waker.wake();
256 }
257 }
258 }
259
260 pub(crate) fn is_closed(&self) -> bool {
262 self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
263 }
264
265 pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
266 assert!(
267 num_permits <= Self::MAX_PERMITS,
268 "a semaphore may not have more than MAX_PERMITS permits ({})",
269 Self::MAX_PERMITS
270 );
271 let num_permits = num_permits << Self::PERMIT_SHIFT;
272 let mut curr = self.permits.load(Acquire);
273 loop {
274 if curr & Self::CLOSED == Self::CLOSED {
276 return Err(TryAcquireError::Closed);
277 }
278
279 if curr < num_permits {
281 return Err(TryAcquireError::NoPermits);
282 }
283
284 let next = curr - num_permits;
285
286 match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
287 Ok(_) => {
288 return Ok(());
290 }
291 Err(actual) => curr = actual,
292 }
293 }
294 }
295
296 pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
297 Acquire::new(self, num_permits)
298 }
299
300 fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
306 let mut wakers = WakeList::new();
307 let mut lock = Some(waiters);
308 let mut is_empty = false;
309 while rem > 0 {
310 let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock());
311 'inner: while wakers.can_push() {
312 match waiters.queue.last() {
314 Some(waiter) => {
315 if !waiter.assign_permits(&mut rem) {
316 break 'inner;
317 }
318 }
319 None => {
320 is_empty = true;
321 break 'inner;
324 }
325 };
326 let mut waiter = waiters.queue.pop_back().unwrap();
327 if let Some(waker) =
328 unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }
329 {
330 wakers.push(waker);
331 }
332 }
333
334 if rem > 0 && is_empty {
335 let permits = rem;
336 assert!(
337 permits <= Self::MAX_PERMITS,
338 "cannot add more than MAX_PERMITS permits ({})",
339 Self::MAX_PERMITS
340 );
341 let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release);
342 let prev = prev >> Self::PERMIT_SHIFT;
343 assert!(
344 prev + permits <= Self::MAX_PERMITS,
345 "number of added permits ({}) would overflow MAX_PERMITS ({})",
346 rem,
347 Self::MAX_PERMITS
348 );
349
350 #[cfg(all(tokio_unstable, feature = "tracing"))]
352 self.resource_span.in_scope(|| {
353 tracing::trace!(
354 target: "runtime::resource::state_update",
355 permits = rem,
356 permits.op = "add",
357 )
358 });
359
360 rem = 0;
361 }
362
363 drop(waiters); wakers.wake_all();
366 }
367
368 assert_eq!(rem, 0);
369 }
370
371 pub(crate) fn forget_permits(&self, n: usize) -> usize {
376 if n == 0 {
377 return 0;
378 }
379
380 let mut curr_bits = self.permits.load(Acquire);
381 loop {
382 let curr = curr_bits >> Self::PERMIT_SHIFT;
383 let new = curr.saturating_sub(n);
384 match self.permits.compare_exchange_weak(
385 curr_bits,
386 new << Self::PERMIT_SHIFT,
387 AcqRel,
388 Acquire,
389 ) {
390 Ok(_) => return std::cmp::min(curr, n),
391 Err(actual) => curr_bits = actual,
392 };
393 }
394 }
395
396 fn poll_acquire(
397 &self,
398 cx: &mut Context<'_>,
399 num_permits: usize,
400 node: Pin<&mut Waiter>,
401 queued: bool,
402 ) -> Poll<Result<(), AcquireError>> {
403 let mut acquired = 0;
404
405 let needed = if queued {
406 node.state.load(Acquire) << Self::PERMIT_SHIFT
407 } else {
408 num_permits << Self::PERMIT_SHIFT
409 };
410
411 let mut lock = None;
412 let mut curr = self.permits.load(Acquire);
415 let mut waiters = loop {
416 if curr & Self::CLOSED > 0 {
418 return Poll::Ready(Err(AcquireError::closed()));
419 }
420
421 let mut remaining = 0;
422 let total = curr
423 .checked_add(acquired)
424 .expect("number of permits must not overflow");
425 let (next, acq) = if total >= needed {
426 let next = curr - (needed - acquired);
427 (next, needed >> Self::PERMIT_SHIFT)
428 } else {
429 remaining = (needed - acquired) - curr;
430 (0, curr >> Self::PERMIT_SHIFT)
431 };
432
433 if remaining > 0 && lock.is_none() {
434 lock = Some(self.waiters.lock());
442 }
443
444 match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
445 Ok(_) => {
446 acquired += acq;
447 if remaining == 0 {
448 if !queued {
449 #[cfg(all(tokio_unstable, feature = "tracing"))]
450 self.resource_span.in_scope(|| {
451 tracing::trace!(
452 target: "runtime::resource::state_update",
453 permits = acquired,
454 permits.op = "sub",
455 );
456 tracing::trace!(
457 target: "runtime::resource::async_op::state_update",
458 permits_obtained = acquired,
459 permits.op = "add",
460 )
461 });
462
463 return Poll::Ready(Ok(()));
464 } else if lock.is_none() {
465 break self.waiters.lock();
466 }
467 }
468 break lock.expect("lock must be acquired before waiting");
469 }
470 Err(actual) => curr = actual,
471 }
472 };
473
474 if waiters.closed {
475 return Poll::Ready(Err(AcquireError::closed()));
476 }
477
478 #[cfg(all(tokio_unstable, feature = "tracing"))]
479 self.resource_span.in_scope(|| {
480 tracing::trace!(
481 target: "runtime::resource::state_update",
482 permits = acquired,
483 permits.op = "sub",
484 )
485 });
486
487 if node.assign_permits(&mut acquired) {
488 self.add_permits_locked(acquired, waiters);
489 return Poll::Ready(Ok(()));
490 }
491
492 assert_eq!(acquired, 0);
493 let mut old_waker = None;
494
495 node.waker.with_mut(|waker| {
497 let waker = unsafe { &mut *waker };
499 if waker
501 .as_ref()
502 .map_or(true, |waker| !waker.will_wake(cx.waker()))
503 {
504 old_waker = std::mem::replace(waker, Some(cx.waker().clone()));
505 }
506 });
507
508 if !queued {
510 let node = unsafe {
511 let node = Pin::into_inner_unchecked(node) as *mut _;
512 NonNull::new_unchecked(node)
513 };
514
515 waiters.queue.push_front(node);
516 }
517 drop(waiters);
518 drop(old_waker);
519
520 Poll::Pending
521 }
522}
523
524impl fmt::Debug for Semaphore {
525 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
526 fmt.debug_struct("Semaphore")
527 .field("permits", &self.available_permits())
528 .finish()
529 }
530}
531
532impl Waiter {
533 fn new(
534 num_permits: usize,
535 #[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
536 ) -> Self {
537 Waiter {
538 waker: UnsafeCell::new(None),
539 state: AtomicUsize::new(num_permits),
540 pointers: linked_list::Pointers::new(),
541 #[cfg(all(tokio_unstable, feature = "tracing"))]
542 ctx,
543 _p: PhantomPinned,
544 }
545 }
546
547 fn assign_permits(&self, n: &mut usize) -> bool {
551 let mut curr = self.state.load(Acquire);
552 loop {
553 let assign = cmp::min(curr, *n);
554 let next = curr - assign;
555 match self.state.compare_exchange(curr, next, AcqRel, Acquire) {
556 Ok(_) => {
557 *n -= assign;
558 #[cfg(all(tokio_unstable, feature = "tracing"))]
559 self.ctx.async_op_span.in_scope(|| {
560 tracing::trace!(
561 target: "runtime::resource::async_op::state_update",
562 permits_obtained = assign,
563 permits.op = "add",
564 );
565 });
566 return next == 0;
567 }
568 Err(actual) => curr = actual,
569 }
570 }
571 }
572}
573
574impl Future for Acquire<'_> {
575 type Output = Result<(), AcquireError>;
576
577 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
578 ready!(crate::trace::trace_leaf(cx));
579
580 #[cfg(all(tokio_unstable, feature = "tracing"))]
581 let _resource_span = self.node.ctx.resource_span.clone().entered();
582 #[cfg(all(tokio_unstable, feature = "tracing"))]
583 let _async_op_span = self.node.ctx.async_op_span.clone().entered();
584 #[cfg(all(tokio_unstable, feature = "tracing"))]
585 let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered();
586
587 let (node, semaphore, needed, queued) = self.project();
588
589 #[cfg(all(tokio_unstable, feature = "tracing"))]
591 let coop = ready!(trace_poll_op!(
592 "poll_acquire",
593 crate::runtime::coop::poll_proceed(cx),
594 ));
595
596 #[cfg(not(all(tokio_unstable, feature = "tracing")))]
597 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
598
599 let result = match semaphore.poll_acquire(cx, needed, node, *queued) {
600 Poll::Pending => {
601 *queued = true;
602 Poll::Pending
603 }
604 Poll::Ready(r) => {
605 coop.made_progress();
606 r?;
607 *queued = false;
608 Poll::Ready(Ok(()))
609 }
610 };
611
612 #[cfg(all(tokio_unstable, feature = "tracing"))]
613 return trace_poll_op!("poll_acquire", result);
614
615 #[cfg(not(all(tokio_unstable, feature = "tracing")))]
616 return result;
617 }
618}
619
620impl<'a> Acquire<'a> {
621 fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
622 #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
623 return Self {
624 node: Waiter::new(num_permits),
625 semaphore,
626 num_permits,
627 queued: false,
628 };
629
630 #[cfg(all(tokio_unstable, feature = "tracing"))]
631 return semaphore.resource_span.in_scope(|| {
632 let async_op_span =
633 tracing::trace_span!("runtime.resource.async_op", source = "Acquire::new");
634 let async_op_poll_span = async_op_span.in_scope(|| {
635 tracing::trace!(
636 target: "runtime::resource::async_op::state_update",
637 permits_requested = num_permits,
638 permits.op = "override",
639 );
640
641 tracing::trace!(
642 target: "runtime::resource::async_op::state_update",
643 permits_obtained = 0usize,
644 permits.op = "override",
645 );
646
647 tracing::trace_span!("runtime.resource.async_op.poll")
648 });
649
650 let ctx = trace::AsyncOpTracingCtx {
651 async_op_span,
652 async_op_poll_span,
653 resource_span: semaphore.resource_span.clone(),
654 };
655
656 Self {
657 node: Waiter::new(num_permits, ctx),
658 semaphore,
659 num_permits,
660 queued: false,
661 }
662 });
663 }
664
665 fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
666 fn is_unpin<T: Unpin>() {}
667 unsafe {
668 is_unpin::<&Semaphore>();
671 is_unpin::<&mut bool>();
672 is_unpin::<usize>();
673
674 let this = self.get_unchecked_mut();
675 (
676 Pin::new_unchecked(&mut this.node),
677 this.semaphore,
678 this.num_permits,
679 &mut this.queued,
680 )
681 }
682 }
683}
684
685impl Drop for Acquire<'_> {
686 fn drop(&mut self) {
687 if !self.queued {
690 return;
691 }
692
693 let mut waiters = self.semaphore.waiters.lock();
697
698 let node = NonNull::from(&mut self.node);
700 unsafe { waiters.queue.remove(node) };
702
703 let acquired_permits = self.num_permits - self.node.state.load(Acquire);
704 if acquired_permits > 0 {
705 self.semaphore.add_permits_locked(acquired_permits, waiters);
706 }
707 }
708}
709
// SAFETY (review note): `Acquire` embeds a `Waiter`, whose waker slot is an
// `UnsafeCell`, which presumably suppresses the automatic `Sync` impl. Within
// this file that cell is only accessed while the wait-list mutex is held.
unsafe impl Sync for Acquire<'_> {}
716
717impl AcquireError {
720 fn closed() -> AcquireError {
721 AcquireError(())
722 }
723}
724
725impl fmt::Display for AcquireError {
726 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
727 write!(fmt, "semaphore closed")
728 }
729}
730
// Uses the default `Error` methods; the message comes from the `Display` impl.
impl std::error::Error for AcquireError {}
732
733impl TryAcquireError {
736 #[allow(dead_code)] pub(crate) fn is_closed(&self) -> bool {
739 matches!(self, TryAcquireError::Closed)
740 }
741
742 #[allow(dead_code)] pub(crate) fn is_no_permits(&self) -> bool {
746 matches!(self, TryAcquireError::NoPermits)
747 }
748}
749
750impl fmt::Display for TryAcquireError {
751 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
752 match self {
753 TryAcquireError::Closed => write!(fmt, "semaphore closed"),
754 TryAcquireError::NoPermits => write!(fmt, "no permits available"),
755 }
756 }
757}
758
// Uses the default `Error` methods; the message comes from the `Display` impl.
impl std::error::Error for TryAcquireError {}
760
// Lets `Waiter` be stored in the intrusive `LinkedList`. Handles are raw
// `NonNull<Waiter>` pointers, so converting to and from a handle is the
// identity.
// NOTE(review): the safety contract of `Link` is defined with the trait and is
// not visible here; presumably it requires that linked nodes are not moved.
unsafe impl linked_list::Link for Waiter {
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    // A handle already is the raw pointer.
    fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> {
        *handle
    }

    // The raw pointer is used directly as the handle.
    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    // Delegates to the accessor generated by `generate_addr_of_methods!`.
    unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        Waiter::addr_of_pointers(target)
    }
}