vfs/
object_request.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::execution_scope::ExecutionScope;
6use crate::node::{self, Node};
7use crate::ProtocolsExt;
8use fidl::endpoints::{ControlHandle, ProtocolMarker, RequestStream, ServerEnd};
9use fidl::epitaph::ChannelEpitaphExt;
10use fidl::{AsHandleRef, HandleBased};
11use futures::FutureExt;
12use std::future::Future;
13use std::sync::Arc;
14use zx_status::Status;
15use {fidl_fuchsia_io as fio, fuchsia_async as fasync};
16
17/// Wraps the channel provided in the open methods and provide convenience methods for sending
18/// appropriate responses.  It also records actions that should be taken upon successful connection
19/// such as truncating file objects.
20#[derive(Debug)]
21pub struct ObjectRequest {
22    // The channel.
23    object_request: fidl::Channel,
24
25    // What should be sent first.
26    what_to_send: ObjectRequestSend,
27
28    // Attributes required in the open method.
29    attributes: fio::NodeAttributesQuery,
30
31    // Creation attributes.
32    create_attributes: Option<Box<fio::MutableNodeAttributes>>,
33
34    /// Truncate the object before use.
35    pub truncate: bool,
36}
37
38impl ObjectRequest {
39    pub(crate) fn new_deprecated(
40        object_request: fidl::Channel,
41        what_to_send: ObjectRequestSend,
42        attributes: fio::NodeAttributesQuery,
43        create_attributes: Option<&fio::MutableNodeAttributes>,
44        truncate: bool,
45    ) -> Self {
46        assert!(!object_request.is_invalid_handle());
47        let create_attributes = create_attributes.map(|a| Box::new(a.clone()));
48        Self { object_request, what_to_send, attributes, create_attributes, truncate }
49    }
50
51    /// Create a new [`ObjectRequest`] from a set of [`fio::Flags`] and [`fio::Options`]`.
52    pub fn new(flags: fio::Flags, options: &fio::Options, object_request: fidl::Channel) -> Self {
53        Self::new_deprecated(
54            object_request,
55            if flags.get_representation() {
56                ObjectRequestSend::OnRepresentation
57            } else {
58                ObjectRequestSend::Nothing
59            },
60            options.attributes.unwrap_or(fio::NodeAttributesQuery::empty()),
61            options.create_attributes.as_ref(),
62            flags.is_truncate(),
63        )
64    }
65
66    pub(crate) fn what_to_send(&self) -> ObjectRequestSend {
67        self.what_to_send
68    }
69
70    pub fn attributes(&self) -> fio::NodeAttributesQuery {
71        self.attributes
72    }
73
74    pub fn create_attributes(&self) -> Option<&fio::MutableNodeAttributes> {
75        self.create_attributes.as_deref()
76    }
77
78    pub fn options(&self) -> fio::Options {
79        fio::Options {
80            attributes: (!self.attributes.is_empty()).then_some(self.attributes),
81            create_attributes: self
82                .create_attributes
83                .as_ref()
84                .map(|a| fio::MutableNodeAttributes::clone(&a)),
85            ..Default::default()
86        }
87    }
88
89    /// Returns the request stream after sending requested information.
90    pub async fn into_request_stream<T: Representation>(
91        self,
92        connection: &T,
93    ) -> Result<<T::Protocol as ProtocolMarker>::RequestStream, Status> {
94        let stream = fio::NodeRequestStream::from_channel(fasync::Channel::from_channel(
95            self.object_request,
96        ));
97        match self.what_to_send {
98            ObjectRequestSend::OnOpen => {
99                let control_handle = stream.control_handle();
100                let node_info = connection.node_info().await.map_err(|s| {
101                    control_handle.shutdown_with_epitaph(s);
102                    s
103                })?;
104                send_on_open(&stream.control_handle(), node_info)?;
105            }
106            ObjectRequestSend::OnRepresentation => {
107                let control_handle = stream.control_handle();
108                let representation =
109                    connection.get_representation(self.attributes).await.map_err(|s| {
110                        control_handle.shutdown_with_epitaph(s);
111                        s
112                    })?;
113                control_handle
114                    .send_on_representation(representation)
115                    .map_err(|_| Status::PEER_CLOSED)?;
116            }
117            ObjectRequestSend::Nothing => {}
118        }
119        Ok(stream.cast_stream())
120    }
121
122    /// Converts to ServerEnd<T>.
123    pub fn into_server_end<T>(self) -> ServerEnd<T> {
124        ServerEnd::new(self.object_request)
125    }
126
127    /// Extracts the channel (without sending on_open).
128    pub fn into_channel(self) -> fidl::Channel {
129        self.object_request
130    }
131
132    /// Extracts the channel after sending on_open.
133    pub fn into_channel_after_sending_on_open(
134        self,
135        node_info: fio::NodeInfoDeprecated,
136    ) -> Result<fidl::Channel, Status> {
137        let stream = fio::NodeRequestStream::from_channel(fasync::Channel::from_channel(
138            self.object_request,
139        ));
140        send_on_open(&stream.control_handle(), node_info)?;
141        let (inner, _is_terminated) = stream.into_inner();
142        // It's safe to unwrap here because inner is clearly the only Arc reference left.
143        Ok(Arc::try_unwrap(inner).unwrap().into_channel().into())
144    }
145
146    /// Terminates the object request with the given status.
147    pub fn shutdown(self, status: Status) {
148        if self.object_request.is_invalid_handle() {
149            return;
150        }
151        if let ObjectRequestSend::OnOpen = self.what_to_send {
152            let (_, control_handle) = ServerEnd::<fio::NodeMarker>::new(self.object_request)
153                .into_stream_and_control_handle();
154            let _ = control_handle.send_on_open_(status.into_raw(), None);
155            control_handle.shutdown_with_epitaph(status);
156        } else {
157            let _ = self.object_request.close_with_epitaph(status);
158        }
159    }
160
161    /// Calls `f` and sends an error on the object request channel upon failure.
162    pub fn handle<T>(
163        mut self,
164        f: impl FnOnce(ObjectRequestRef<'_>) -> Result<T, Status>,
165    ) -> Option<T> {
166        match f(&mut self) {
167            Ok(o) => Some(o),
168            Err(s) => {
169                self.shutdown(s);
170                None
171            }
172        }
173    }
174
175    /// Calls `f` and sends an error on the object request channel upon failure.
176    pub async fn handle_async(
177        mut self,
178        f: impl AsyncFnOnce(&mut ObjectRequest) -> Result<(), Status>,
179    ) {
180        if let Err(s) = f(&mut self).await {
181            self.shutdown(s);
182        }
183    }
184
185    /// Waits until the request has a request waiting in its channel.  Returns immediately if this
186    /// request requires sending an initial event such as OnOpen or OnRepresentation.  Returns
187    /// `true` if the channel is readable (rather than just closed).
188    pub async fn wait_till_ready(&self) -> bool {
189        if !matches!(self.what_to_send, ObjectRequestSend::Nothing) {
190            return true;
191        }
192        let signals = fasync::OnSignalsRef::new(
193            self.object_request.as_handle_ref(),
194            fidl::Signals::OBJECT_READABLE | fidl::Signals::CHANNEL_PEER_CLOSED,
195        )
196        .await
197        .unwrap();
198        signals.contains(fidl::Signals::OBJECT_READABLE)
199    }
200
201    /// Take the ObjectRequest.  The caller is responsible for sending errors.
202    pub fn take(&mut self) -> ObjectRequest {
203        assert!(!self.object_request.is_invalid_handle());
204        Self {
205            object_request: std::mem::replace(
206                &mut self.object_request,
207                fidl::Handle::invalid().into(),
208            ),
209            what_to_send: self.what_to_send,
210            attributes: self.attributes,
211            create_attributes: self.create_attributes.take(),
212            truncate: self.truncate,
213        }
214    }
215
216    /// Constructs a new connection to `node` and spawns an async `Task` that will handle requests
217    /// on the connection. `f` is a callback that constructs the connection but it will not be
218    /// called if the connection is supposed to be a node connection. This should be called from
219    /// within a [`ObjectRequest::handle_async`] callback.
220    pub async fn create_connection<C, N>(
221        &mut self,
222        scope: ExecutionScope,
223        node: Arc<N>,
224        protocols: impl ProtocolsExt,
225    ) -> Result<(), Status>
226    where
227        C: ConnectionCreator<N>,
228        N: Node,
229    {
230        assert!(!self.object_request.is_invalid_handle());
231        if protocols.is_node() {
232            node::Connection::create(scope, node, protocols, self).await
233        } else {
234            C::create(scope, node, protocols, self).await
235        }
236    }
237
238    /// Constructs a new connection to `node` and spawns an async `Task` that will handle requests
239    /// on the connection. `f` is a callback that constructs the connection but it will not be
240    /// called if the connection is supposed to be a node connection. This should be called from
241    /// within a [`ObjectRequest::handle`] callback.
242    ///
243    /// This method synchronously calls async code and may require spawning an extra Task if the
244    /// async code does something asynchronous. `create_connection` should be preferred if the
245    /// caller is already in an async context.
246    pub fn create_connection_sync<C, N>(
247        self,
248        scope: ExecutionScope,
249        node: Arc<N>,
250        protocols: impl ProtocolsExt,
251    ) where
252        C: ConnectionCreator<N>,
253        N: Node,
254    {
255        assert!(!self.object_request.is_invalid_handle());
256        if protocols.is_node() {
257            self.create_connection_sync_or_spawn::<node::Connection<N>, N>(scope, node, protocols);
258        } else {
259            self.create_connection_sync_or_spawn::<C, N>(scope, node, protocols);
260        }
261    }
262
263    fn create_connection_sync_or_spawn<C, N>(
264        self,
265        scope: ExecutionScope,
266        node: Arc<N>,
267        protocols: impl ProtocolsExt,
268    ) where
269        C: ConnectionCreator<N>,
270        N: Node,
271    {
272        let scope2 = scope.clone();
273        let fut = self.handle_async(async |object_request| {
274            C::create(scope2, node, protocols, object_request).await
275        });
276        run_synchronous_future_or_spawn(scope, fut);
277    }
278}
279
280pub type ObjectRequestRef<'a> = &'a mut ObjectRequest;
281
/// The initial event, if any, that must be sent on an [`ObjectRequest`]'s channel before the
/// connection starts serving requests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// NOTE(review): every variant is constructed in this file, so this allow looks redundant. It is
// kept in case some build configuration leaves a variant unused; confirm before removing.
#[allow(dead_code)]
pub(crate) enum ObjectRequestSend {
    /// Send an io1 `OnOpen` event (selected by `fio::OpenFlags::DESCRIBE`).
    OnOpen,
    /// Send an `OnRepresentation` event (selected when `fio::Flags` requests a representation).
    OnRepresentation,
    /// Send nothing before serving the channel.
    Nothing,
}
289
290/// Trait to get either fio::Representation or fio::NodeInfoDeprecated.  Connection types
291/// should implement this.
292pub trait Representation {
293    /// The protocol used for the connection.
294    type Protocol: ProtocolMarker;
295
296    /// Returns io2's Representation for the object.
297    fn get_representation(
298        &self,
299        requested_attributes: fio::NodeAttributesQuery,
300    ) -> impl Future<Output = Result<fio::Representation, Status>> + Send;
301
302    /// Returns io1's NodeInfoDeprecated.
303    fn node_info(&self) -> impl Future<Output = Result<fio::NodeInfoDeprecated, Status>> + Send;
304}
305
306/// Convenience trait for converting [`fio::Flags`] and [`fio::OpenFlags`] into ObjectRequest.
307///
308/// If [`fio::Options`] need to be specified, use [`ObjectRequest::new`].
309pub trait ToObjectRequest: ProtocolsExt {
310    fn to_object_request(&self, object_request: impl Into<fidl::Handle>) -> ObjectRequest;
311}
312
313impl ToObjectRequest for fio::OpenFlags {
314    fn to_object_request(&self, object_request: impl Into<fidl::Handle>) -> ObjectRequest {
315        ObjectRequest::new_deprecated(
316            object_request.into().into(),
317            if self.contains(fio::OpenFlags::DESCRIBE) {
318                ObjectRequestSend::OnOpen
319            } else {
320                ObjectRequestSend::Nothing
321            },
322            fio::NodeAttributesQuery::empty(),
323            None,
324            self.is_truncate(),
325        )
326    }
327}
328
329impl ToObjectRequest for fio::Flags {
330    fn to_object_request(&self, object_request: impl Into<fidl::Handle>) -> ObjectRequest {
331        ObjectRequest::new(*self, &Default::default(), object_request.into().into())
332    }
333}
334
335fn send_on_open(
336    control_handle: &fio::NodeControlHandle,
337    node_info: fio::NodeInfoDeprecated,
338) -> Result<(), Status> {
339    control_handle
340        .send_on_open_(Status::OK.into_raw(), Some(node_info))
341        .map_err(|_| Status::PEER_CLOSED)
342}
343
344/// Trait for constructing connections to nodes.
345pub trait ConnectionCreator<T: Node> {
346    /// Creates a new connection to `node` and spawns a new `Task` to run the connection.
347    fn create<'a>(
348        scope: ExecutionScope,
349        node: Arc<T>,
350        protocols: impl ProtocolsExt,
351        object_request: ObjectRequestRef<'a>,
352    ) -> impl Future<Output = Result<(), Status>> + Send + 'a;
353}
354
355/// Synchronously polls `future` with the expectation that it won't return Pending. If the future
356/// does return Pending then this function will spawn a Task to run the future.
357pub(crate) fn run_synchronous_future_or_spawn(
358    scope: ExecutionScope,
359    future: impl Future<Output = ()> + Send + 'static,
360) {
361    let mut task = scope.new_task(future);
362    let noop_waker = std::task::Waker::noop();
363    let mut cx = std::task::Context::from_waker(&noop_waker);
364
365    match task.poll_unpin(&mut cx) {
366        std::task::Poll::Pending => task.spawn(),
367        std::task::Poll::Ready(()) => {}
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    use crate::execution_scope::yield_to_executor;
    use std::future::ready;

    /// Exercises the path where the future is already complete on the first, inline poll.
    #[fuchsia::test]
    async fn test_run_synchronous_future_or_spawn_with_sync_future() {
        let execution_scope = ExecutionScope::new();
        run_synchronous_future_or_spawn(execution_scope.clone(), ready(()));
        execution_scope.wait().await;
    }

    /// Exercises a future that presumably yields on its first poll, so it must be spawned on the
    /// scope to complete.
    #[fuchsia::test]
    async fn test_run_synchronous_future_or_spawn_with_async_future() {
        let execution_scope = ExecutionScope::new();
        run_synchronous_future_or_spawn(execution_scope.clone(), yield_to_executor());
        execution_scope.wait().await;
    }
}