// fidl_next_protocol/endpoints/connection.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
use core::future::Future;
use core::mem::{ManuallyDrop, MaybeUninit, replace, take};
use core::pin::Pin;
use core::task::{Context, Poll, Waker};
use fidl_constants::EPITAPH_ORDINAL;

use fidl_next_codec::EncodeError;
use pin_project::pin_project;

use crate::concurrency::cell::UnsafeCell;
use crate::concurrency::future::AtomicWaker;
use crate::concurrency::hint::unreachable_unchecked;
use crate::concurrency::sync::Mutex;
use crate::concurrency::sync::atomic::{AtomicUsize, Ordering};
use crate::{
    Flexibility, NonBlockingTransport, ProtocolError, Transport, encode_epitaph, encode_header,
};
22
// Indicates that the connection has been requested to stop. Connections are
// always stopped as they are terminated.
const STOPPING_BIT: usize = 1 << 0;
// Indicates that the connection has been provided a termination reason.
const TERMINATED_BIT: usize = 1 << 1;
// The number of low bits of the state word which are reserved for flags.
const BITS_COUNT: usize = 2;

// Each refcount represents a thread which is attempting to access the shared
// part of the transport.
const REFCOUNT: usize = 1 << BITS_COUNT;

/// A decoded snapshot of a connection's packed state word.
///
/// The lowest `BITS_COUNT` bits hold the `STOPPING_BIT` and `TERMINATED_BIT`
/// flags. The remaining upper bits hold the refcount, which is adjusted in
/// units of `REFCOUNT` so that it never disturbs the flags.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct State(usize);

impl State {
    /// Returns `true` if `STOPPING_BIT` is set in this snapshot.
    fn is_stopping(self) -> bool {
        self.0 & STOPPING_BIT != 0
    }

    /// Returns `true` if `TERMINATED_BIT` is set in this snapshot.
    fn is_terminated(self) -> bool {
        self.0 & TERMINATED_BIT != 0
    }

    /// Returns the refcount stored above the flag bits.
    fn refcount(self) -> usize {
        self.0 >> BITS_COUNT
    }
}
50
51/// A wrapper around a transport which connectivity semantics.
52///
53/// The [`Transport`] trait only provides the bare minimum API surface required
54/// to send and receive data. On top of that, FIDL requires that clients and
55/// servers respect additional messaging semantics. Those semantics are provided
56/// by [`Connection`]:
57///
58/// - `Transport`s are difficult to close because they may be accessed from
59///   several threads simultaneously. `Connection`s provide a mechanism for
60///   gracefully closing transports by causing all sends to pend until the
61///   connection is terminated, and all receives to fail instead of pend.
62/// - FIDL connections may send and receive an epitaph as the final message
63///   before the underlying transport is closed. This epitaph should be provided
64///   to all sends when they fail, which requires additional coordination.
65pub struct Connection<T: Transport> {
66    // The lowest `BITS_COUNT` of this field contain flags indicating the
67    // current state of the transport. The remainder of the upper bits contain
68    // the number of threads attempting to access the `shared` field.
69    state: AtomicUsize,
70    // A thread will drop `shared` if:
71    //
72    // - the connection is dropped before being terminated, or
73    // - it set `TERMINATED_BIT` while the refcount was 0, or
74    // - it decremented the refcount to 0 while `TERMINATED_BIT` was set.
75    //
76    // These cases are handled by `drop`, `terminate`, and `with_shared`
77    // respectively.
78    shared: UnsafeCell<ManuallyDrop<T::Shared>>,
79    stop_waker: AtomicWaker,
80    // TODO: switch this to intrusive linked list in send futures
81    termination_wakers: Mutex<Vec<Waker>>,
82    // Initialized if `TERMINATED_BIT` is set.
83    termination_reason: UnsafeCell<MaybeUninit<ProtocolError<T::Error>>>,
84}
85
86unsafe impl<T: Transport> Send for Connection<T> {}
87unsafe impl<T: Transport> Sync for Connection<T> {}
88
89impl<T: Transport> Drop for Connection<T> {
90    fn drop(&mut self) {
91        self.state.with_mut(|state| {
92            let state = State(*state);
93
94            if !state.is_terminated() {
95                self.shared.with_mut(|shared| {
96                    // SAFETY: The connection was not terminated before being
97                    // dropped, so `shared` has not yet been dropped.
98                    unsafe {
99                        ManuallyDrop::drop(&mut *shared);
100                    }
101                });
102            } else {
103                self.termination_reason.with_mut(|termination_reason| {
104                    // SAFETY: The connection was terminated before being
105                    // dropped, so `termination_reason` is initialized.
106                    unsafe {
107                        MaybeUninit::assume_init_drop(&mut *termination_reason);
108                    }
109                });
110            }
111        });
112    }
113}
114
115impl<T: Transport> Connection<T> {
116    /// Creates a new connection from the shared part of a transport.
117    pub fn new(shared: T::Shared) -> Self {
118        Self {
119            state: AtomicUsize::new(0),
120            shared: UnsafeCell::new(ManuallyDrop::new(shared)),
121            stop_waker: AtomicWaker::new(),
122            termination_wakers: Mutex::new(Vec::new()),
123            termination_reason: UnsafeCell::new(MaybeUninit::uninit()),
124        }
125    }
126
127    /// # Safety
128    ///
129    /// This thread must have loaded `state` with at least `Ordering::Acquire`
130    /// and observed that `TERMINATED_BIT` was set.
131    unsafe fn get_termination_reason_unchecked(&self) -> ProtocolError<T::Error> {
132        self.termination_reason.with(|termination_reason| {
133            // SAFETY: The caller guaranteed that `state` was loaded with at
134            // least `Ordering::Acquire` ordering and observed that
135            // `TERMINATED_BIT` was set.
136            unsafe { MaybeUninit::assume_init_ref(&*termination_reason).clone() }
137        })
138    }
139
140    /// Returns the termination reason for the connection, if any.
141    pub fn get_termination_reason(&self) -> Option<ProtocolError<T::Error>> {
142        if State(self.state.load(Ordering::Acquire)).is_terminated() {
143            // SAFETY: We loaded the state with `Ordering::Acquire` and observed
144            // that `TERMINATED_BIT` was set.
145            unsafe { Some(self.get_termination_reason_unchecked()) }
146        } else {
147            None
148        }
149    }
150
151    /// # Safety
152    ///
153    /// `shared` must not have been dropped. See the documentation on `shared`
154    /// for acceptable criteria.
155    unsafe fn get_shared_unchecked(&self) -> &T::Shared {
156        self.shared.with(|shared| {
157            // SAFETY: The caller guaranteed that `shared` has not been dropped.
158            unsafe { &*shared }
159        })
160    }
161
162    fn with_shared<U>(
163        &self,
164        success: impl FnOnce(&T::Shared) -> U,
165        failure: impl FnOnce(Option<ProtocolError<T::Error>>) -> U,
166    ) -> U {
167        let pre_increment = State(self.state.fetch_add(REFCOUNT, Ordering::Acquire));
168
169        // After the refcount drops to zero (and `shared` is dropped), threads
170        // may still increment and decrement the refcount to attempt to read it.
171        // To avoid dropping `shared` more than once, we prevent the refcount
172        // from being decremented to 0 more than once after `TERMINATED_BIT` is
173        // set.
174        //
175        // We do this by having each thread check whether its increment changed
176        // the refcount from 0 to 1 while `TERMINATED_BIT` was set. If it did,
177        // the thread will not decrement that refcount, leaving it "dangling"
178        // instead. This ensures that the refcount never falls below 1 again.
179        if pre_increment.is_terminated() && pre_increment.refcount() == 0 {
180            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
181            // that `TERMINATED_BIT` was set.
182            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
183            return failure(Some(termination_reason));
184        }
185
186        let mut success_result = None;
187        if !pre_increment.is_stopping() {
188            // SAFETY: Termination always sets `STOPPING_BIT`. We incremented
189            // the refcount while `STOPPING_BIT` was not set, so `shared` won't
190            // be dropped until we decrement our refcount.
191            let shared = unsafe { self.get_shared_unchecked() };
192            success_result = Some(success(shared));
193        }
194
195        let pre_decrement = State(self.state.fetch_sub(REFCOUNT, Ordering::AcqRel));
196
197        if !pre_decrement.is_stopping() {
198            success_result.unwrap()
199        } else if !pre_decrement.is_terminated() {
200            failure(None)
201        } else {
202            // The connection is terminated. If we decremented the refcount to
203            // 0, then we need to drop `shared`.
204            if pre_decrement.refcount() == 1 {
205                self.shared.with_mut(|shared| {
206                    // SAFETY: We decremented the refcount to 0 while
207                    // `TERMINATED_BIT` was set.
208                    unsafe {
209                        ManuallyDrop::drop(&mut *shared);
210                    }
211                });
212            }
213
214            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
215            // that `TERMINATED_BIT` was set.
216            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
217            failure(Some(termination_reason))
218        }
219    }
220
221    /// Sends a message to the underlying transport.
222    ///
223    /// Returns a `SendFutureState` which can be polled to completion.
224    pub fn send_message_raw(
225        &self,
226        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
227    ) -> Result<SendFutureState<T>, EncodeError> {
228        self.with_shared(
229            |shared| {
230                let mut buffer = T::acquire(shared);
231                f(&mut buffer)?;
232                Ok(SendFutureState::Running { future_state: T::begin_send(shared, buffer) })
233            },
234            |error| {
235                Ok(error
236                    // Some(Error) => Terminated
237                    .map(|error| SendFutureState::Terminated { error })
238                    // None => Stopping
239                    .unwrap_or(SendFutureState::Stopping))
240            },
241        )
242    }
243
244    /// Sends a message to the underlying transport.
245    pub fn send_message(
246        &self,
247        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
248    ) -> Result<SendFuture<'_, T>, EncodeError> {
249        Ok(SendFuture { connection: self, state: self.send_message_raw(f)? })
250    }
251
252    /// Sends an epitaph to the underlying transport.
253    ///
254    /// This send ignores the current state of the connection, and does not
255    /// report back any errors encountered while sending.
256    ///
257    /// # Safety
258    ///
259    /// The connection must not be terminated, and the returned future must be
260    /// completed or canceled before the connection is terminated.
261    pub unsafe fn send_epitaph(&self, error: i32) -> SendEpitaphFuture<'_, T> {
262        // SAFETY: The caller has guaranteed that the connection is not
263        // terminated, and will not be terminated until the returned future is
264        // completed or canceled. As long as the connection is not terminated,
265        // `shared` will not be dropped.
266        let shared = unsafe { self.get_shared_unchecked() };
267
268        let mut buffer = T::acquire(shared);
269        encode_header::<T>(&mut buffer, 0, EPITAPH_ORDINAL, Flexibility::Strict).unwrap();
270        encode_epitaph::<T>(&mut buffer, error).unwrap();
271        let future_state = T::begin_send(shared, buffer);
272
273        SendEpitaphFuture { shared, future_state }
274    }
275
276    /// Returns a new [`RecvFuture`] which receives the next message.
277    ///
278    /// # Safety
279    ///
280    /// The connection must not be terminated, and the returned future must be
281    /// completed or canceled before the connection is terminated.
282    pub unsafe fn recv<'a>(&'a self, exclusive: &'a mut T::Exclusive) -> RecvFuture<'a, T> {
283        // SAFETY: The caller has guaranteed that the connection is not
284        // terminated, and will not be terminated until the returned future is
285        // completed or canceled. As long as the connection is not terminated,
286        // `shared` will not be dropped.
287        let shared = unsafe { self.get_shared_unchecked() };
288        let future_state = T::begin_recv(shared, exclusive);
289
290        RecvFuture { connection: self, exclusive, future_state }
291    }
292
293    /// Stops the connection to wait for termination.
294    ///
295    /// This modifies the behavior of this connection's futures:
296    ///
297    /// - Polled [`SendFutureState`]s will return `Poll::Pending` without
298    ///   calling [`poll_send`].
299    /// - Polled [`RecvFuture`]s will call [`poll_recv`], but will return
300    ///   `Poll::Ready` with an error when they would normally return
301    ///   `Poll::Pending`.
302    ///
303    /// [`poll_send`]: Transport::poll_send
304    /// [`poll_recv`]: Transport::poll_recv
305    pub fn stop(&self) {
306        let prev_state = State(self.state.fetch_or(STOPPING_BIT, Ordering::Relaxed));
307        if !prev_state.is_stopping() {
308            self.stop_waker.wake();
309        }
310    }
311
312    /// Terminates the connection.
313    ///
314    /// This causes this connection's futures to return `Poll::Ready` with an
315    /// error of the given termination reason.
316    ///
317    /// # Safety
318    ///
319    /// `terminate` may only be called once per connection.
320    pub unsafe fn terminate(&self, reason: ProtocolError<T::Error>) {
321        self.termination_reason.with_mut(|termination_reason| {
322            // SAFETY: The caller guaranteed that this is the only time
323            // `terminate` is called on this connection.
324            unsafe {
325                termination_reason.write(MaybeUninit::new(reason));
326            }
327        });
328        let pre_terminate =
329            State(self.state.fetch_or(STOPPING_BIT | TERMINATED_BIT, Ordering::AcqRel));
330
331        // If we set `TERMINATED_BIT` and the refcount was 0, then we need to
332        // drop `shared`.
333        if !pre_terminate.is_terminated() && pre_terminate.refcount() == 0 {
334            self.shared.with_mut(|shared| {
335                // SAFETY: We set `TERMINATED_BIT` while the refcount was 0.
336                unsafe {
337                    ManuallyDrop::drop(&mut *shared);
338                }
339            });
340        }
341
342        // Wake all of the futures waiting for a termination reason
343        let wakers = take(&mut *self.termination_wakers.lock().unwrap());
344        for waker in wakers {
345            waker.wake();
346        }
347    }
348}
349
350pub type SendFutureOutput<T> = Result<(), ProtocolError<<T as Transport>::Error>>;
351
352#[pin_project(project = SendFutureStateProj, project_replace = SendFutureStateProjOwn)]
353pub enum SendFutureState<T: Transport> {
354    Running {
355        #[pin]
356        future_state: T::SendFutureState,
357    },
358    Stopping,
359    Terminated {
360        error: ProtocolError<T::Error>,
361    },
362    Waiting {
363        waker_index: usize,
364    },
365    Finished,
366}
367
368impl<T: Transport> SendFutureState<T> {
369    fn register_termination_waker(
370        mut self: Pin<&mut Self>,
371        cx: &mut Context<'_>,
372        connection: &Connection<T>,
373        waker_index: Option<usize>,
374    ) -> Poll<SendFutureOutput<T>> {
375        let mut wakers = connection.termination_wakers.lock().unwrap();
376
377        // Re-check the state now that we're holding the lock again. This
378        // prevents us from adding wakers after termination (which would "leak"
379        // them).
380        if let Some(termination_reason) = connection.get_termination_reason() {
381            Poll::Ready(Err(termination_reason))
382        } else {
383            let waker = cx.waker().clone();
384            if let Some(waker_index) = waker_index {
385                // Overwrite an existing waker
386                let old_waker = replace(&mut wakers[waker_index], waker);
387
388                // Drop the old waker outside of the mutex lock
389                drop(wakers);
390                drop(old_waker);
391            } else {
392                // Insert a new waker
393                let waker_index = wakers.len();
394                wakers.push(waker);
395
396                // Update the state outside of the mutex lock. If we were
397                // running then a `T::SendFutureState` may be dropped.
398                drop(wakers);
399                self.set(SendFutureState::Waiting { waker_index });
400            }
401            Poll::Pending
402        }
403    }
404
405    pub fn poll_send(
406        mut self: Pin<&mut Self>,
407        cx: &mut Context<'_>,
408        connection: &Connection<T>,
409    ) -> Poll<SendFutureOutput<T>> {
410        match self.as_mut().project() {
411            SendFutureStateProj::Running { future_state } => {
412                let result = connection.with_shared(
413                    |shared| {
414                        T::poll_send(future_state, cx, shared)
415                            // `Err(Some(error))` =>
416                            //   `Err(Some(TransportError(error)))`
417                            .map_err(|error| error.map(ProtocolError::TransportError))
418                    },
419                    |error| Poll::Ready(Err(error)),
420                );
421
422                let result = match result {
423                    Poll::Pending => Poll::Pending,
424                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
425                    Poll::Ready(Err(None)) => {
426                        self.as_mut().register_termination_waker(cx, connection, None)
427                    }
428                    Poll::Ready(Err(Some(error))) => Poll::Ready(Err(error)),
429                };
430
431                if result.is_ready() {
432                    self.set(Self::Finished);
433                }
434
435                result
436            }
437            SendFutureStateProj::Stopping => self.register_termination_waker(cx, connection, None),
438            SendFutureStateProj::Terminated { .. } => {
439                let state = self.project_replace(Self::Finished);
440                let SendFutureStateProjOwn::Terminated { error } = state else {
441                    // SAFETY: We just checked that our state is Terminated.
442                    unsafe { unreachable_unchecked() }
443                };
444                Poll::Ready(Err(error))
445            }
446            SendFutureStateProj::Waiting { waker_index } => {
447                let waker_index = *waker_index;
448                self.register_termination_waker(cx, connection, Some(waker_index))
449            }
450            SendFutureStateProj::Finished => {
451                panic!("SendFuture polled after returning `Poll::Ready`")
452            }
453        }
454    }
455
456    pub fn send_immediately(self, connection: &Connection<T>) -> SendFutureOutput<T>
457    where
458        T: NonBlockingTransport,
459    {
460        match self {
461            SendFutureState::Running { mut future_state } => {
462                connection.with_shared(
463                    |shared| {
464                        // Connection is running, try to send immediately.
465                        T::send_immediately(&mut future_state, shared).map_err(|e| {
466                            // Immediate send failed:
467                            // - `None` => `PeerClosed`
468                            // - `Some(T::Error)` => `TransportError(T::Error)`
469                            e.map_or(ProtocolError::PeerClosed, ProtocolError::TransportError)
470                        })
471                    },
472                    // Getting shared failed, but we may have a termination
473                    // reason. If we don't have one, return `Stopped`.
474                    |error| Err(error.unwrap_or(ProtocolError::Stopped)),
475                )
476            }
477            SendFutureState::Stopping | SendFutureState::Waiting { waker_index: _ } => {
478                // Try to get the termination reason. If we don't have one yet,
479                // return `Stopped`.
480                Err(connection.get_termination_reason().unwrap_or(ProtocolError::Stopped))
481            }
482            SendFutureState::Terminated { error } => Err(error),
483            SendFutureState::Finished => panic!("SendFuture polled after returning `Poll::Ready`"),
484        }
485    }
486}
487
488/// A future which sends an encoded message to a connection.
489#[must_use = "futures do nothing unless polled"]
490#[pin_project]
491pub struct SendFuture<'a, T: Transport> {
492    connection: &'a Connection<T>,
493    #[pin]
494    state: SendFutureState<T>,
495}
496
497impl<T: NonBlockingTransport> SendFuture<'_, T> {
498    /// Completes the send operation synchronously and without blocking.
499    ///
500    /// Using this method prevents transports from applying backpressure. Prefer
501    /// awaiting when possible to allow for backpressure.
502    ///
503    /// Because failed sends return immediately, `send_immediately` may observe
504    /// transport closure prematurely. This can manifest as this method
505    /// returning `Err(PeerClosed)` or `Err(Stopped)` when it should have
506    /// returned `Err(PeerClosedWithEpitaph)`. Prefer awaiting when possible for
507    /// correctness.
508    pub fn send_immediately(self) -> SendFutureOutput<T> {
509        self.state.send_immediately(self.connection)
510    }
511}
512
513impl<T: Transport> Future for SendFuture<'_, T> {
514    type Output = SendFutureOutput<T>;
515
516    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
517        let this = self.project();
518        this.state.poll_send(cx, this.connection)
519    }
520}
521
522#[pin_project]
523pub struct SendEpitaphFuture<'a, T: Transport> {
524    shared: &'a T::Shared,
525    #[pin]
526    future_state: T::SendFutureState,
527}
528
529impl<T: Transport> Future for SendEpitaphFuture<'_, T> {
530    type Output = Result<(), Option<T::Error>>;
531
532    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
533        let this = self.project();
534        T::poll_send(this.future_state, cx, this.shared)
535    }
536}
537
538/// A future which receives an encoded message over the transport.
539#[must_use = "futures do nothing unless polled"]
540#[pin_project]
541pub struct RecvFuture<'a, T: Transport> {
542    connection: &'a Connection<T>,
543    exclusive: &'a mut T::Exclusive,
544    #[pin]
545    future_state: T::RecvFutureState,
546}
547
548impl<T: Transport> Future for RecvFuture<'_, T> {
549    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
550
551    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
552        let this = self.project();
553
554        // SAFETY: This future is created by `Connection::recv`. The connection
555        // will not be terminated until this is completed or canceled, and so
556        // `shared` will not be dropped.
557        let shared = unsafe { this.connection.get_shared_unchecked() };
558
559        let termination_reason = match T::poll_recv(this.future_state, cx, shared, this.exclusive) {
560            Poll::Pending => {
561                // Receive didn't complete, register waker before
562                // re-checking state.
563                this.connection.stop_waker.register_by_ref(cx.waker());
564                let state = State(this.connection.state.load(Ordering::Relaxed));
565                if state.is_stopping() {
566                    // The connection is stopping. Return an error that the
567                    // connection has been stopped.
568                    ProtocolError::Stopped
569                } else {
570                    // Still running, we'll get polled again later.
571                    return Poll::Pending;
572                }
573            }
574
575            // Receive succeeded.
576            Poll::Ready(Ok(buffer)) => return Poll::Ready(Ok(buffer)),
577
578            // Normal failure: return peer closed error.
579            Poll::Ready(Err(None)) => ProtocolError::PeerClosed,
580
581            // Abnormal failure: return transport error.
582            Poll::Ready(Err(Some(error))) => ProtocolError::TransportError(error),
583        };
584
585        Poll::Ready(Err(termination_reason))
586    }
587}