fidl_next_protocol/endpoints/connection.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use core::future::Future;
use core::mem::{ManuallyDrop, MaybeUninit, replace, take};
use core::pin::Pin;
use core::task::{Context, Poll, Waker};

use fidl_next_codec::EncodeError;
use pin_project::pin_project;

use crate::concurrency::cell::UnsafeCell;
use crate::concurrency::future::AtomicWaker;
use crate::concurrency::hint::unreachable_unchecked;
use crate::concurrency::sync::Mutex;
use crate::concurrency::sync::atomic::{AtomicUsize, Ordering};
use crate::{NonBlockingTransport, ProtocolError, Transport, encode_epitaph, encode_header};

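/// The message ordinal reserved for epitaphs.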
pub const ORDINAL_EPITAPH: u64 = 0xffff_ffff_ffff_ffff;

// Indicates that the connection has been requested to stop. Connections are
// always stopped when they are terminated.
const STOPPING_BIT: usize = 1 << 0;
// Indicates that the connection has been provided a termination reason.
const TERMINATED_BIT: usize = 1 << 1;
const BITS_COUNT: usize = 2;

// Each refcount represents a thread which is attempting to access the shared
// part of the transport.
const REFCOUNT: usize = 1 << BITS_COUNT;
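// For example (illustrative only), a state word equal to
// `STOPPING_BIT | TERMINATED_BIT | (2 * REFCOUNT)` (`0b1011`) is both stopping
// and terminated, with a refcount of 2.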

#[derive(Clone, Copy)]
struct State(usize);

impl State {
    fn is_stopping(self) -> bool {
        self.0 & STOPPING_BIT != 0
    }

    fn is_terminated(self) -> bool {
        self.0 & TERMINATED_BIT != 0
    }

    fn refcount(self) -> usize {
        self.0 >> BITS_COUNT
    }
}

/// A wrapper around a transport which adds connectivity semantics.
///
/// The [`Transport`] trait only provides the bare minimum API surface required
/// to send and receive data. On top of that, FIDL requires that clients and
/// servers respect additional messaging semantics. Those semantics are provided
/// by [`Connection`]:
///
/// - `Transport`s are difficult to close because they may be accessed from
///   several threads simultaneously. `Connection`s provide a mechanism for
///   gracefully closing transports by causing all sends to pend until the
///   connection is terminated, and all receives to fail instead of pend.
/// - FIDL connections may send and receive an epitaph as the final message
///   before the underlying transport is closed. This epitaph should be provided
///   to all sends when they fail, which requires additional coordination.
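///
/// In broad strokes, a `Connection` is created from the shared half of a
/// transport with [`new`](Self::new), used to send and receive messages,
/// optionally asked to wind down with [`stop`](Self::stop), and finally given
/// a termination reason exactly once via [`terminate`](Self::terminate), after
/// which the shared transport state is dropped once no threads are accessing
/// it.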
pub struct Connection<T: Transport> {
    // The lowest `BITS_COUNT` bits of this field contain flags indicating the
    // current state of the transport. The remaining upper bits contain the
    // number of threads attempting to access the `shared` field.
    state: AtomicUsize,
    // A thread will drop `shared` if:
    //
    // - the connection is dropped before being terminated, or
    // - it set `TERMINATED_BIT` while the refcount was 0, or
    // - it decremented the refcount to 0 while `TERMINATED_BIT` was set.
    //
    // These cases are handled by `drop`, `terminate`, and `with_shared`
    // respectively.
    shared: UnsafeCell<ManuallyDrop<T::Shared>>,
    stop_waker: AtomicWaker,
    // TODO: switch this to intrusive linked list in send futures
    termination_wakers: Mutex<Vec<Waker>>,
    // Initialized if `TERMINATED_BIT` is set.
    termination_reason: UnsafeCell<MaybeUninit<ProtocolError<T::Error>>>,
}

unsafe impl<T: Transport> Send for Connection<T> {}
unsafe impl<T: Transport> Sync for Connection<T> {}

impl<T: Transport> Drop for Connection<T> {
    fn drop(&mut self) {
        self.state.with_mut(|state| {
            let state = State(*state);

            if !state.is_terminated() {
                self.shared.with_mut(|shared| {
                    // SAFETY: The connection was not terminated before being
                    // dropped, so `shared` has not yet been dropped.
                    unsafe {
                        ManuallyDrop::drop(&mut *shared);
                    }
                });
            } else {
                self.termination_reason.with_mut(|termination_reason| {
                    // SAFETY: The connection was terminated before being
                    // dropped, so `termination_reason` is initialized.
                    unsafe {
                        MaybeUninit::assume_init_drop(&mut *termination_reason);
                    }
                });
            }
        });
    }
}

impl<T: Transport> Connection<T> {
    /// Creates a new connection from the shared part of a transport.
    pub fn new(shared: T::Shared) -> Self {
        Self {
            state: AtomicUsize::new(0),
            shared: UnsafeCell::new(ManuallyDrop::new(shared)),
            stop_waker: AtomicWaker::new(),
            termination_wakers: Mutex::new(Vec::new()),
            termination_reason: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// # Safety
    ///
    /// This thread must have loaded `state` with at least `Ordering::Acquire`
    /// and observed that `TERMINATED_BIT` was set.
    unsafe fn get_termination_reason_unchecked(&self) -> ProtocolError<T::Error> {
        self.termination_reason.with(|termination_reason| {
            // SAFETY: The caller guaranteed that `state` was loaded with at
            // least `Ordering::Acquire` ordering and observed that
            // `TERMINATED_BIT` was set.
            unsafe { MaybeUninit::assume_init_ref(&*termination_reason).clone() }
        })
    }

    /// Returns the termination reason for the connection, if any.
    pub fn get_termination_reason(&self) -> Option<ProtocolError<T::Error>> {
        if State(self.state.load(Ordering::Acquire)).is_terminated() {
            // SAFETY: We loaded the state with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            unsafe { Some(self.get_termination_reason_unchecked()) }
        } else {
            None
        }
    }

    /// # Safety
    ///
    /// `shared` must not have been dropped. See the documentation on `shared`
    /// for acceptable criteria.
    unsafe fn get_shared_unchecked(&self) -> &T::Shared {
        self.shared.with(|shared| {
            // SAFETY: The caller guaranteed that `shared` has not been dropped.
            unsafe { &*shared }
        })
    }

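    // Runs `success` with a reference to the shared transport state if the
    // connection has not been asked to stop, and returns its result if the
    // connection was still running after the access. Otherwise, returns the
    // result of `failure`, passing along the termination reason if one has
    // been set.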
    fn with_shared<U>(
        &self,
        success: impl FnOnce(&T::Shared) -> U,
        failure: impl FnOnce(Option<ProtocolError<T::Error>>) -> U,
    ) -> U {
        let pre_increment = State(self.state.fetch_add(REFCOUNT, Ordering::Acquire));

        // After the refcount drops to zero (and `shared` is dropped), threads
        // may still increment and decrement the refcount to attempt to read it.
        // To avoid dropping `shared` more than once, we prevent the refcount
        // from being decremented to 0 more than once after `TERMINATED_BIT` is
        // set.
        //
        // We do this by having each thread check whether its increment changed
        // the refcount from 0 to 1 while `TERMINATED_BIT` was set. If it did,
        // the thread will not decrement that refcount, leaving it "dangling"
        // instead. This ensures that the refcount never falls below 1 again.
        if pre_increment.is_terminated() && pre_increment.refcount() == 0 {
            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
            return failure(Some(termination_reason));
        }

        let mut success_result = None;
        if !pre_increment.is_stopping() {
            // SAFETY: Termination always sets `STOPPING_BIT`. We incremented
            // the refcount while `STOPPING_BIT` was not set, so `shared` won't
            // be dropped until we decrement our refcount.
            let shared = unsafe { self.get_shared_unchecked() };
            success_result = Some(success(shared));
        }

        let pre_decrement = State(self.state.fetch_sub(REFCOUNT, Ordering::AcqRel));

        if !pre_decrement.is_stopping() {
            success_result.unwrap()
        } else if !pre_decrement.is_terminated() {
            failure(None)
        } else {
            // The connection is terminated. If we decremented the refcount to
            // 0, then we need to drop `shared`.
            if pre_decrement.refcount() == 1 {
                self.shared.with_mut(|shared| {
                    // SAFETY: We decremented the refcount to 0 while
                    // `TERMINATED_BIT` was set.
                    unsafe {
                        ManuallyDrop::drop(&mut *shared);
                    }
                });
            }

            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
            failure(Some(termination_reason))
        }
    }

    /// Sends a message to the underlying transport.
    ///
    /// Returns a `SendFutureState` which can be polled to completion.
    pub fn send_message_raw(
        &self,
        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
    ) -> Result<SendFutureState<T>, EncodeError> {
        self.with_shared(
            |shared| {
                let mut buffer = T::acquire(shared);
                f(&mut buffer)?;
                Ok(SendFutureState::Running { future_state: T::begin_send(shared, buffer) })
            },
            |error| {
                Ok(error
                    // Some(Error) => Terminated
                    .map(|error| SendFutureState::Terminated { error })
                    // None => Stopping
                    .unwrap_or(SendFutureState::Stopping))
            },
        )
    }

    /// Sends a message to the underlying transport.
    pub fn send_message(
        &self,
        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
    ) -> Result<SendFuture<'_, T>, EncodeError> {
        Ok(SendFuture { connection: self, state: self.send_message_raw(f)? })
    }

    /// Sends an epitaph to the underlying transport.
    ///
    /// This send ignores the current state of the connection, and does not
    /// report back any errors encountered while sending.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated, and the returned future must be
    /// completed or canceled before the connection is terminated.
    pub unsafe fn send_epitaph(&self, error: i32) -> SendEpitaphFuture<'_, T> {
        // SAFETY: The caller has guaranteed that the connection is not
        // terminated, and will not be terminated until the returned future is
        // completed or canceled. As long as the connection is not terminated,
        // `shared` will not be dropped.
        let shared = unsafe { self.get_shared_unchecked() };

        let mut buffer = T::acquire(shared);
        encode_header::<T>(&mut buffer, 0, ORDINAL_EPITAPH).unwrap();
        encode_epitaph::<T>(&mut buffer, error).unwrap();
        let future_state = T::begin_send(shared, buffer);

        SendEpitaphFuture { shared, future_state }
    }

    /// Returns a new [`RecvFuture`] which receives the next message.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated, and the returned future must be
    /// completed or canceled before the connection is terminated.
    pub unsafe fn recv<'a>(&'a self, exclusive: &'a mut T::Exclusive) -> RecvFuture<'a, T> {
        // SAFETY: The caller has guaranteed that the connection is not
        // terminated, and will not be terminated until the returned future is
        // completed or canceled. As long as the connection is not terminated,
        // `shared` will not be dropped.
        let shared = unsafe { self.get_shared_unchecked() };
        let future_state = T::begin_recv(shared, exclusive);

        RecvFuture { connection: self, exclusive, future_state }
    }

    /// Stops the connection to wait for termination.
    ///
    /// This modifies the behavior of this connection's futures:
    ///
    /// - Polled [`SendFutureState`]s will return `Poll::Pending` without
    ///   calling [`poll_send`].
    /// - Polled [`RecvFuture`]s will call [`poll_recv`], but will return
    ///   `Poll::Ready` with an error when they would normally return
    ///   `Poll::Pending`.
    ///
    /// [`poll_send`]: Transport::poll_send
    /// [`poll_recv`]: Transport::poll_recv
    pub fn stop(&self) {
        let prev_state = State(self.state.fetch_or(STOPPING_BIT, Ordering::Relaxed));
        if !prev_state.is_stopping() {
            self.stop_waker.wake();
        }
    }

    /// Terminates the connection.
    ///
    /// This causes this connection's futures to return `Poll::Ready` with an
    /// error containing the given termination reason.
    ///
    /// # Safety
    ///
    /// `terminate` may only be called once per connection.
    pub unsafe fn terminate(&self, reason: ProtocolError<T::Error>) {
        self.termination_reason.with_mut(|termination_reason| {
            // SAFETY: The caller guaranteed that this is the only time
            // `terminate` is called on this connection.
            unsafe {
                termination_reason.write(MaybeUninit::new(reason));
            }
        });
        let pre_terminate =
            State(self.state.fetch_or(STOPPING_BIT | TERMINATED_BIT, Ordering::AcqRel));

        // If we set `TERMINATED_BIT` and the refcount was 0, then we need to
        // drop `shared`.
        if !pre_terminate.is_terminated() && pre_terminate.refcount() == 0 {
            self.shared.with_mut(|shared| {
                // SAFETY: We set `TERMINATED_BIT` while the refcount was 0.
                unsafe {
                    ManuallyDrop::drop(&mut *shared);
                }
            });
        }

        // Wake all of the futures waiting for a termination reason.
        let wakers = take(&mut *self.termination_wakers.lock().unwrap());
        for waker in wakers {
            waker.wake();
        }
    }
}

pub type SendFutureOutput<T> = Result<(), ProtocolError<<T as Transport>::Error>>;

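// The state machine for an in-flight send:
//
// - `Running`: the send has been started on the transport and is being polled
//   to completion.
// - `Stopping`: the connection was stopped before the send began; the future
//   pends until a termination reason becomes available.
// - `Terminated`: the connection was already terminated when the send began;
//   the future resolves with that termination reason.
// - `Waiting`: a waker has been registered in `termination_wakers` and the
//   future is waiting for a termination reason.
// - `Finished`: the future has resolved and must not be polled again.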
#[pin_project(project = SendFutureStateProj, project_replace = SendFutureStateProjOwn)]
pub enum SendFutureState<T: Transport> {
    Running {
        #[pin]
        future_state: T::SendFutureState,
    },
    Stopping,
    Terminated {
        error: ProtocolError<T::Error>,
    },
    Waiting {
        waker_index: usize,
    },
    Finished,
}

impl<T: Transport> SendFutureState<T> {
    fn register_termination_waker(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        connection: &Connection<T>,
        waker_index: Option<usize>,
    ) -> Poll<SendFutureOutput<T>> {
        let mut wakers = connection.termination_wakers.lock().unwrap();

        // Re-check the state now that we're holding the lock again. This
        // prevents us from adding wakers after termination (which would "leak"
        // them).
        if let Some(termination_reason) = connection.get_termination_reason() {
            Poll::Ready(Err(termination_reason))
        } else {
            let waker = cx.waker().clone();
            if let Some(waker_index) = waker_index {
                // Overwrite an existing waker
                let old_waker = replace(&mut wakers[waker_index], waker);

                // Drop the old waker outside of the mutex lock
                drop(wakers);
                drop(old_waker);
            } else {
                // Insert a new waker
                let waker_index = wakers.len();
                wakers.push(waker);

                // Update the state outside of the mutex lock. If we were
                // running then a `T::SendFutureState` may be dropped.
                drop(wakers);
                self.set(SendFutureState::Waiting { waker_index });
            }
            Poll::Pending
        }
    }

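    /// Polls the send operation against `connection`.
    ///
    /// If the connection has been stopped, this registers interest in the
    /// termination reason and pends until one is provided.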
    pub fn poll_send(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        connection: &Connection<T>,
    ) -> Poll<SendFutureOutput<T>> {
        match self.as_mut().project() {
            SendFutureStateProj::Running { future_state } => {
                let result = connection.with_shared(
                    |shared| {
                        T::poll_send(future_state, cx, shared)
                            // `Err(Some(error))` =>
                            //   `Err(Some(TransportError(error)))`
                            .map_err(|error| error.map(ProtocolError::TransportError))
                    },
                    |error| Poll::Ready(Err(error)),
                );

                let result = match result {
                    Poll::Pending => Poll::Pending,
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    Poll::Ready(Err(None)) => {
                        self.as_mut().register_termination_waker(cx, connection, None)
                    }
                    Poll::Ready(Err(Some(error))) => Poll::Ready(Err(error)),
                };

                if result.is_ready() {
                    self.set(Self::Finished);
                }

                result
            }
            SendFutureStateProj::Stopping => self.register_termination_waker(cx, connection, None),
            SendFutureStateProj::Terminated { .. } => {
                let state = self.project_replace(Self::Finished);
                let SendFutureStateProjOwn::Terminated { error } = state else {
                    // SAFETY: We just checked that our state is Terminated.
                    unsafe { unreachable_unchecked() }
                };
                Poll::Ready(Err(error))
            }
            SendFutureStateProj::Waiting { waker_index } => {
                let waker_index = *waker_index;
                self.register_termination_waker(cx, connection, Some(waker_index))
            }
            SendFutureStateProj::Finished => {
                panic!("SendFuture polled after returning `Poll::Ready`")
            }
        }
    }

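    /// Completes the send operation synchronously and without blocking.
    ///
    /// See [`SendFuture::send_immediately`] for caveats about premature
    /// observation of transport closure.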
    pub fn send_immediately(self, connection: &Connection<T>) -> SendFutureOutput<T>
    where
        T: NonBlockingTransport,
    {
        match self {
            SendFutureState::Running { mut future_state } => {
                connection.with_shared(
                    |shared| {
                        // Connection is running, try to send immediately.
                        T::send_immediately(&mut future_state, shared).map_err(|e| {
                            // Immediate send failed:
                            // - `None` => `PeerClosed`
                            // - `Some(T::Error)` => `TransportError(T::Error)`
                            e.map_or(ProtocolError::PeerClosed, ProtocolError::TransportError)
                        })
                    },
                    // Getting shared failed, but we may have a termination
                    // reason. If we don't have one, return `Stopped`.
                    |error| Err(error.unwrap_or(ProtocolError::Stopped)),
                )
            }
            SendFutureState::Stopping | SendFutureState::Waiting { waker_index: _ } => {
                // Try to get the termination reason. If we don't have one yet,
                // return `Stopped`.
                Err(connection.get_termination_reason().unwrap_or(ProtocolError::Stopped))
            }
            SendFutureState::Terminated { error } => Err(error),
            SendFutureState::Finished => panic!("SendFuture polled after returning `Poll::Ready`"),
        }
    }
}

/// A future which sends an encoded message to a connection.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct SendFuture<'a, T: Transport> {
    connection: &'a Connection<T>,
    #[pin]
    state: SendFutureState<T>,
}

impl<T: NonBlockingTransport> SendFuture<'_, T> {
    /// Completes the send operation synchronously and without blocking.
    ///
    /// Using this method prevents transports from applying backpressure. Prefer
    /// awaiting when possible to allow for backpressure.
    ///
    /// Because failed sends return immediately, `send_immediately` may observe
    /// transport closure prematurely. This can manifest as this method
    /// returning `Err(PeerClosed)` or `Err(Stopped)` when it should have
    /// returned `Err(PeerClosedWithEpitaph)`. Prefer awaiting when possible for
    /// correctness.
    pub fn send_immediately(self) -> SendFutureOutput<T> {
        self.state.send_immediately(self.connection)
    }
}

impl<T: Transport> Future for SendFuture<'_, T> {
    type Output = SendFutureOutput<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.state.poll_send(cx, this.connection)
    }
}

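/// A future which sends an epitaph over the shared part of a transport.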
#[pin_project]
pub struct SendEpitaphFuture<'a, T: Transport> {
    shared: &'a T::Shared,
    #[pin]
    future_state: T::SendFutureState,
}

impl<T: Transport> Future for SendEpitaphFuture<'_, T> {
    type Output = Result<(), Option<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        T::poll_send(this.future_state, cx, this.shared)
    }
}

/// A future which receives an encoded message over the transport.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct RecvFuture<'a, T: Transport> {
    connection: &'a Connection<T>,
    exclusive: &'a mut T::Exclusive,
    #[pin]
    future_state: T::RecvFutureState,
}

impl<T: Transport> Future for RecvFuture<'_, T> {
    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // SAFETY: This future is created by `Connection::recv`. The connection
        // will not be terminated until this is completed or canceled, and so
        // `shared` will not be dropped.
        let shared = unsafe { this.connection.get_shared_unchecked() };

        let termination_reason = match T::poll_recv(this.future_state, cx, shared, this.exclusive) {
            Poll::Pending => {
                // Receive didn't complete, register waker before
                // re-checking state.
                this.connection.stop_waker.register_by_ref(cx.waker());
                let state = State(this.connection.state.load(Ordering::Relaxed));
                if state.is_stopping() {
                    // The connection is stopping. Return an error indicating
                    // that the connection has been stopped.
                    ProtocolError::Stopped
                } else {
                    // Still running, we'll get polled again later.
                    return Poll::Pending;
                }
            }

            // Receive succeeded.
            Poll::Ready(Ok(buffer)) => return Poll::Ready(Ok(buffer)),

            // Normal failure: return peer closed error.
            Poll::Ready(Err(None)) => ProtocolError::PeerClosed,

            // Abnormal failure: return transport error.
            Poll::Ready(Err(Some(error))) => ProtocolError::TransportError(error),
        };

        Poll::Ready(Err(termination_reason))
    }
}
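
// A minimal sanity check (sketch) of the state-word bit layout defined at the
// top of this file: the flag bits and the refcount should decode independently
// of one another.
#[cfg(test)]
mod state_tests {
    use super::*;

    #[test]
    fn state_word_layout() {
        // A freshly created connection starts with an all-zero state word.
        let initial = State(0);
        assert!(!initial.is_stopping());
        assert!(!initial.is_terminated());
        assert_eq!(initial.refcount(), 0);

        // Both flag bits set with two outstanding accessors.
        let state = State(STOPPING_BIT | TERMINATED_BIT | 2 * REFCOUNT);
        assert!(state.is_stopping());
        assert!(state.is_terminated());
        assert_eq!(state.refcount(), 2);
    }
}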