want/
lib.rs

1#![doc(html_root_url = "https://docs.rs/want/0.3.0")]
2#![deny(warnings)]
3#![deny(missing_docs)]
4#![deny(missing_debug_implementations)]
5
6//! A Futures channel-like utility to signal when a value is wanted.
7//!
//! Futures are supposed to be lazy, only starting work when `Future::poll`
//! is called. The same is true of `Stream`s, but when using a channel as
//! a `Stream`, it can be hard to know if the receiver is ready for the next
//! value.
12//!
13//! Put another way, given a `(tx, rx)` from `futures::sync::mpsc::channel()`,
14//! how can the sender (`tx`) know when the receiver (`rx`) actually wants more
15//! work to be produced? Just because there is room in the channel buffer
16//! doesn't mean the work would be used by the receiver.
17//!
//! This is where something like `want` comes in. Added to a channel, you can
//! make sure that the `tx` only creates the message and sends it when the `rx`
//! has called `poll()` for it, and the buffer was empty.
21//!
22//! # Example
23//!
24//! ```nightly
25//! # //#![feature(async_await)]
26//! extern crate want;
27//!
28//! # fn spawn<T>(_t: T) {}
29//! # fn we_still_want_message() -> bool { true }
30//! # fn mpsc_channel() -> (Tx, Rx) { (Tx, Rx) }
31//! # struct Tx;
32//! # impl Tx { fn send<T>(&mut self, _: T) {} }
33//! # struct Rx;
34//! # impl Rx { async fn recv(&mut self) -> Option<Expensive> { Some(Expensive) } }
35//!
36//! // Some message that is expensive to produce.
37//! struct Expensive;
38//!
39//! // Some futures-aware MPSC channel...
40//! let (mut tx, mut rx) = mpsc_channel();
41//!
42//! // And our `want` channel!
43//! let (mut gv, mut tk) = want::new();
44//!
45//!
46//! // Our receiving task...
47//! spawn(async move {
48//!     // Maybe something comes up that prevents us from ever
49//!     // using the expensive message.
50//!     //
51//!     // Without `want`, the "send" task may have started to
52//!     // produce the expensive message even though we wouldn't
53//!     // be able to use it.
54//!     if !we_still_want_message() {
55//!         return;
56//!     }
57//!
58//!     // But we can use it! So tell the `want` channel.
59//!     tk.want();
60//!
61//!     match rx.recv().await {
62//!         Some(_msg) => println!("got a message"),
63//!         None => println!("DONE"),
64//!     }
65//! });
66//!
67//! // Our sending task
68//! spawn(async move {
69//!     // It's expensive to create a new message, so we wait until the
70//!     // receiving end truly *wants* the message.
71//!     if let Err(_closed) = gv.want().await {
72//!         // Looks like they will never want it...
73//!         return;
74//!     }
75//!
76//!     // They want it, let's go!
77//!     tx.send(Expensive);
78//! });
79//!
80//! # fn main() {}
81//! ```
82
83#[macro_use]
84extern crate log;
85
86use std::fmt;
87use std::future::Future;
88use std::mem;
89use std::pin::Pin;
90use std::sync::Arc;
91use std::sync::atomic::AtomicUsize;
92// SeqCst is the only ordering used to ensure accessing the state and
93// TryLock are never re-ordered.
94use std::sync::atomic::Ordering::SeqCst;
95use std::task::{self, Poll, Waker};
96
97
98use try_lock::TryLock;
99
100/// Create a new `want` channel.
101pub fn new() -> (Giver, Taker) {
102    let inner = Arc::new(Inner {
103        state: AtomicUsize::new(State::Idle.into()),
104        task: TryLock::new(None),
105    });
106    let inner2 = inner.clone();
107    (
108        Giver {
109            inner: inner,
110        },
111        Taker {
112            inner: inner2,
113        },
114    )
115}
116
117/// An entity that gives a value when wanted.
118pub struct Giver {
119    inner: Arc<Inner>,
120}
121
122/// An entity that wants a value.
123pub struct Taker {
124    inner: Arc<Inner>,
125}
126
127/// A cloneable `Giver`.
128///
129/// It differs from `Giver` in that you cannot poll for `want`. It's only
130/// usable as a cancellation watcher.
131#[derive(Clone)]
132pub struct SharedGiver {
133    inner: Arc<Inner>,
134}
135
/// Error value reported to the `Giver` once the `Taker` has canceled its
/// interest in a value.
pub struct Closed {
    // Private unit field: keeps the type from being constructed outside
    // this module.
    _inner: (),
}
140
/// The value kept in the shared atomic state word.
///
/// The explicit discriminants are the `usize` encoding used by the two
/// conversions below.
#[derive(Clone, Copy, Debug)]
enum State {
    Idle = 0,
    Want = 1,
    Give = 2,
    Closed = 3,
}

impl From<State> for usize {
    /// Encodes a `State` as its discriminant.
    fn from(s: State) -> usize {
        s as usize
    }
}

impl From<usize> for State {
    /// Decodes a stored state word.
    ///
    /// # Panics
    ///
    /// Panics if `num` is not the encoding of any `State`.
    fn from(num: usize) -> State {
        // Indexed by encoding: `DECODED[i] as usize == i`.
        const DECODED: [State; 4] = [State::Idle, State::Want, State::Give, State::Closed];
        match DECODED.get(num) {
            Some(&state) => state,
            None => unreachable!("unknown state: {}", num),
        }
    }
}
171
172struct Inner {
173    state: AtomicUsize,
174    task: TryLock<Option<Waker>>,
175}
176
177// ===== impl Giver ======
178
179impl Giver {
180    /// Returns a `Future` that fulfills when the `Taker` has done some action.
181    pub fn want<'a>(&'a mut self) -> impl Future<Output = Result<(), Closed>> + 'a {
182        Want(self)
183    }
184
185    /// Poll whether the `Taker` has registered interest in another value.
186    ///
187    /// - If the `Taker` has called `want()`, this returns `Async::Ready(())`.
188    /// - If the `Taker` has not called `want()` since last poll, this
189    ///   returns `Async::NotReady`, and parks the current task to be notified
190    ///   when the `Taker` does call `want()`.
191    /// - If the `Taker` has canceled (or dropped), this returns `Closed`.
192    ///
193    /// After knowing that the Taker is wanting, the state can be reset by
194    /// calling [`give`](Giver::give).
195    pub fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Closed>> {
196        loop {
197            let state = self.inner.state.load(SeqCst).into();
198            match state {
199                State::Want => {
200                    trace!("poll_want: taker wants!");
201                    return Poll::Ready(Ok(()));
202                },
203                State::Closed => {
204                    trace!("poll_want: closed");
205                    return Poll::Ready(Err(Closed { _inner: () }));
206                },
207                State::Idle | State::Give => {
208                    // Taker doesn't want anything yet, so park.
209                    if let Some(mut locked) = self.inner.task.try_lock_order(SeqCst, SeqCst) {
210
211                        // While we have the lock, try to set to GIVE.
212                        let old = self.inner.state.compare_and_swap(
213                            state.into(),
214                            State::Give.into(),
215                            SeqCst,
216                        );
217                        // If it's still the first state (Idle or Give), park current task.
218                        if old == state.into() {
219                            let park = locked.as_ref()
220                                .map(|w| !w.will_wake(cx.waker()))
221                                .unwrap_or(true);
222                            if park {
223                                let old = mem::replace(&mut *locked, Some(cx.waker().clone()));
224                                drop(locked);
225                                old.map(|prev_task| {
226                                    // there was an old task parked here.
227                                    // it might be waiting to be notified,
228                                    // so poke it before dropping.
229                                    prev_task.wake();
230                                });
231                            }
232                            return Poll::Pending;
233                        }
234                        // Otherwise, something happened! Go around the loop again.
235                    } else {
236                        // if we couldn't take the lock, then a Taker has it.
237                        // The *ONLY* reason is because it is in the process of notifying us
238                        // of its want.
239                        //
240                        // We need to loop again to see what state it was changed to.
241                    }
242                },
243            }
244        }
245    }
246
247    /// Mark the state as idle, if the Taker currently is wanting.
248    ///
249    /// Returns true if Taker was wanting, false otherwise.
250    #[inline]
251    pub fn give(&self) -> bool {
252        // only set to IDLE if it is still Want
253        self.inner.state.compare_and_swap(
254            State::Want.into(),
255            State::Idle.into(),
256            SeqCst,
257        ) == State::Want.into()
258    }
259
260    /// Check if the `Taker` has called `want()` without parking a task.
261    ///
262    /// This is safe to call outside of a futures task context, but other
263    /// means of being notified is left to the user.
264    #[inline]
265    pub fn is_wanting(&self) -> bool {
266        self.inner.state.load(SeqCst) == State::Want.into()
267    }
268
269
270    /// Check if the `Taker` has canceled interest without parking a task.
271    #[inline]
272    pub fn is_canceled(&self) -> bool {
273        self.inner.state.load(SeqCst) == State::Closed.into()
274    }
275
276    /// Converts this into a `SharedGiver`.
277    #[inline]
278    pub fn shared(self) -> SharedGiver {
279        SharedGiver {
280            inner: self.inner,
281        }
282    }
283}
284
285impl fmt::Debug for Giver {
286    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
287        f.debug_struct("Giver")
288            .field("state", &self.inner.state())
289            .finish()
290    }
291}
292
293// ===== impl SharedGiver ======
294
295impl SharedGiver {
296    /// Check if the `Taker` has called `want()` without parking a task.
297    ///
298    /// This is safe to call outside of a futures task context, but other
299    /// means of being notified is left to the user.
300    #[inline]
301    pub fn is_wanting(&self) -> bool {
302        self.inner.state.load(SeqCst) == State::Want.into()
303    }
304
305
306    /// Check if the `Taker` has canceled interest without parking a task.
307    #[inline]
308    pub fn is_canceled(&self) -> bool {
309        self.inner.state.load(SeqCst) == State::Closed.into()
310    }
311}
312
313impl fmt::Debug for SharedGiver {
314    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
315        f.debug_struct("SharedGiver")
316            .field("state", &self.inner.state())
317            .finish()
318    }
319}
320
321// ===== impl Taker ======
322
323impl Taker {
324    /// Signal to the `Giver` that the want is canceled.
325    ///
326    /// This is useful to tell that the channel is closed if you cannot
327    /// drop the value yet.
328    #[inline]
329    pub fn cancel(&mut self) {
330        trace!("signal: {:?}", State::Closed);
331        self.signal(State::Closed)
332    }
333
334    /// Signal to the `Giver` that a value is wanted.
335    #[inline]
336    pub fn want(&mut self) {
337        debug_assert!(
338            self.inner.state.load(SeqCst) != State::Closed.into(),
339            "want called after cancel"
340        );
341        trace!("signal: {:?}", State::Want);
342        self.signal(State::Want)
343    }
344
345    #[inline]
346    fn signal(&mut self, state: State) {
347        let old_state = self.inner.state.swap(state.into(), SeqCst).into();
348        match old_state {
349            State::Idle | State::Want | State::Closed => (),
350            State::Give => {
351                loop {
352                    if let Some(mut locked) = self.inner.task.try_lock_order(SeqCst, SeqCst) {
353                        if let Some(task) = locked.take() {
354                            drop(locked);
355                            trace!("signal found waiting giver, notifying");
356                            task.wake();
357                        }
358                        return;
359                    } else {
360                        // if we couldn't take the lock, then a Giver has it.
361                        // The *ONLY* reason is because it is in the process of parking.
362                        //
363                        // We need to loop and take the lock so we can notify this task.
364                    }
365                }
366            },
367        }
368    }
369}
370
371impl Drop for Taker {
372    #[inline]
373    fn drop(&mut self) {
374        self.signal(State::Closed);
375    }
376}
377
378impl fmt::Debug for Taker {
379    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
380        f.debug_struct("Taker")
381            .field("state", &self.inner.state())
382            .finish()
383    }
384}
385
386// ===== impl Closed ======
387
388impl fmt::Debug for Closed {
389    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
390        f.debug_struct("Closed")
391            .finish()
392    }
393}
394
395// ===== impl Inner ======
396
397impl Inner {
398    #[inline]
399    fn state(&self) -> State {
400        self.state.load(SeqCst).into()
401    }
402}
403
// ===== impl Want ======
405
406struct Want<'a>(&'a mut Giver);
407
408
409impl Future for Want<'_> {
410    type Output = Result<(), Closed>;
411
412    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
413        self.0.poll_want(cx)
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use tokio_sync::oneshot;

    /// Runs `f` to completion on the calling thread.
    fn block_on<F: Future>(f: F) -> F::Output {
        tokio_executor::enter()
            .expect("block_on enter")
            .block_on(f)
    }

    /// A want signalled before the first poll is seen immediately.
    #[test]
    fn want_ready() {
        let (mut giver, mut taker) = new();

        taker.want();

        let outcome = block_on(giver.want());
        outcome.unwrap();
    }

    /// A want signalled from another thread resolves the `Giver`'s future,
    /// and `give` then resets the wanting state exactly once.
    #[test]
    fn want_notify_0() {
        let (mut giver, mut taker) = new();
        let (done_tx, done_rx) = oneshot::channel();

        thread::spawn(move || {
            taker.want();
            // use a oneshot to keep this thread alive
            // until other thread was notified of want
            block_on(done_rx).expect("rx");
        });

        let outcome = block_on(giver.want());
        outcome.expect("want");

        assert!(giver.is_wanting(), "still wanting after poll_want success");
        assert!(giver.give(), "give is true when wanting");

        assert!(!giver.is_wanting(), "no longer wanting after give");
        assert!(!giver.is_canceled(), "give doesn't cancel");

        assert!(!giver.give(), "give is false if not wanting");

        done_tx.send(()).expect("tx");
    }

    // NOTE(review): the disabled test below appears to be written against the
    // futures 0.1 executor API (`spawn`, `Notify`, `NotifyHandle`) and a
    // `poll_want()` that takes no context - confirm before re-enabling.
    /*
    /// This tests that if the Giver moves tasks after parking,
    /// it will still wake up the correct task.
    #[test]
    fn want_notify_moving_tasks() {
        use std::sync::Arc;
        use futures::executor::{spawn, Notify, NotifyHandle};

        struct WantNotify;

        impl Notify for WantNotify {
            fn notify(&self, _id: usize) {
            }
        }

        fn n() -> NotifyHandle {
            Arc::new(WantNotify).into()
        }

        let (mut gv, mut tk) = new();

        let mut s = spawn(poll_fn(move || {
            gv.poll_want()
        }));

        // Register with t1 as the task::current()
        let t1 = n();
        assert!(s.poll_future_notify(&t1, 1).unwrap().is_not_ready());

        thread::spawn(move || {
            thread::sleep(::std::time::Duration::from_millis(100));
            tk.want();
        });

        // And now, move to a ThreadNotify task.
        s.into_inner().wait().expect("poll_want");
    }
    */

    /// Cancellation is observed whether it is explicit, caused by dropping
    /// the `Taker`, or happens on another thread.
    #[test]
    fn cancel() {
        // explicit
        let (mut giver_explicit, mut taker_explicit) = new();

        assert!(!giver_explicit.is_canceled());

        taker_explicit.cancel();

        assert!(giver_explicit.is_canceled());
        let outcome = block_on(giver_explicit.want());
        outcome.unwrap_err();

        // implicit
        let (mut giver_implicit, taker_implicit) = new();

        assert!(!giver_implicit.is_canceled());

        drop(taker_implicit);

        assert!(giver_implicit.is_canceled());
        let outcome = block_on(giver_implicit.want());
        outcome.unwrap_err();

        // notifies
        let (mut giver_notified, taker_moved) = new();

        thread::spawn(move || {
            let _tk = taker_moved;
            // and dropped
        });

        let outcome = block_on(giver_notified.want());
        outcome.unwrap_err();
    }

    // NOTE(review): the disabled stress test below appears to use the
    // futures 0.1 `poll_fn` / `Async` / `wait()` API and an `mpsc` channel
    // that is not imported here - confirm before re-enabling.
    /*
    #[test]
    fn stress() {
        let nthreads = 5;
        let nwants = 100;

        for _ in 0..nthreads {
            let (mut gv, mut tk) = new();
            let (mut tx, mut rx) = mpsc::channel(0);

            // rx thread
            thread::spawn(move || {
                let mut cnt = 0;
                poll_fn(move || {
                    while cnt < nwants {
                        let n = match rx.poll().expect("rx poll") {
                            Async::Ready(n) => n.expect("rx opt"),
                            Async::NotReady => {
                                tk.want();
                                return Ok(Async::NotReady);
                            },
                        };
                        assert_eq!(cnt, n);
                        cnt += 1;
                    }
                    Ok::<_, ()>(Async::Ready(()))
                }).wait().expect("rx wait");
            });

            // tx thread
            thread::spawn(move || {
                let mut cnt = 0;
                let nsent = poll_fn(move || {
                    loop {
                        while let Ok(()) = tx.try_send(cnt) {
                            cnt += 1;
                        }
                        match gv.poll_want() {
                            Ok(Async::Ready(_)) => (),
                            Ok(Async::NotReady) => return Ok::<_, ()>(Async::NotReady),
                            Err(_) => return Ok(Async::Ready(cnt)),
                        }
                    }
                }).wait().expect("tx wait");

                assert_eq!(nsent, nwants);
            }).join().expect("thread join");
        }
    }
    */
}