Skip to main content

vfs/
object_request.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::ProtocolsExt;
6use crate::execution_scope::ExecutionScope;
7use crate::node::{self, Node};
8use fidl::endpoints::{ControlHandle, ProtocolMarker, RequestStream, ServerEnd};
9use fidl::epitaph::ChannelEpitaphExt;
10use futures::FutureExt;
11use std::future::Future;
12use std::sync::Arc;
13use zx_status::Status;
14use {fidl_fuchsia_io as fio, fuchsia_async as fasync};
15
16/// Wraps the channel provided in the open methods and provide convenience methods for sending
17/// appropriate responses.  It also records actions that should be taken upon successful connection
18/// such as truncating file objects.
19#[derive(Debug)]
20pub struct ObjectRequest {
21    // The channel.
22    object_request: fidl::Channel,
23
24    // What should be sent first.
25    what_to_send: ObjectRequestSend,
26
27    // Attributes required in the open method.
28    attributes: fio::NodeAttributesQuery,
29
30    // Creation attributes.
31    create_attributes: Option<Box<fio::MutableNodeAttributes>>,
32
33    /// Truncate the object before use.
34    pub truncate: bool,
35}
36
37impl ObjectRequest {
38    pub(crate) fn new_deprecated(
39        object_request: fidl::Channel,
40        what_to_send: ObjectRequestSend,
41        attributes: fio::NodeAttributesQuery,
42        create_attributes: Option<&fio::MutableNodeAttributes>,
43        truncate: bool,
44    ) -> Self {
45        assert!(!object_request.as_handle_ref().is_invalid());
46        let create_attributes = create_attributes.map(|a| Box::new(a.clone()));
47        Self { object_request, what_to_send, attributes, create_attributes, truncate }
48    }
49
50    /// Create a new [`ObjectRequest`] from a set of [`fio::Flags`] and [`fio::Options`]`.
51    pub fn new(flags: fio::Flags, options: &fio::Options, object_request: fidl::Channel) -> Self {
52        Self::new_deprecated(
53            object_request,
54            if flags.contains(fio::Flags::FLAG_SEND_REPRESENTATION) {
55                ObjectRequestSend::OnRepresentation
56            } else {
57                ObjectRequestSend::Nothing
58            },
59            options.attributes.unwrap_or(fio::NodeAttributesQuery::empty()),
60            options.create_attributes.as_ref(),
61            flags.is_truncate(),
62        )
63    }
64
65    pub(crate) fn what_to_send(&self) -> ObjectRequestSend {
66        self.what_to_send
67    }
68
69    pub fn attributes(&self) -> fio::NodeAttributesQuery {
70        self.attributes
71    }
72
73    pub fn create_attributes(&self) -> Option<&fio::MutableNodeAttributes> {
74        self.create_attributes.as_deref()
75    }
76
77    pub fn options(&self) -> fio::Options {
78        fio::Options {
79            attributes: (!self.attributes.is_empty()).then_some(self.attributes),
80            create_attributes: self
81                .create_attributes
82                .as_ref()
83                .map(|a| fio::MutableNodeAttributes::clone(&a)),
84            ..Default::default()
85        }
86    }
87
88    /// Returns the request stream after sending requested information.
89    pub async fn into_request_stream<T: Representation>(
90        self,
91        connection: &T,
92    ) -> Result<<T::Protocol as ProtocolMarker>::RequestStream, Status> {
93        let stream = fio::NodeRequestStream::from_channel(fasync::Channel::from_channel(
94            self.object_request,
95        ));
96        match self.what_to_send {
97            ObjectRequestSend::OnOpen => {
98                let control_handle = stream.control_handle();
99                let node_info = connection.node_info().await.map_err(|s| {
100                    control_handle.shutdown_with_epitaph(s);
101                    s
102                })?;
103                send_on_open(&stream.control_handle(), node_info)?;
104            }
105            ObjectRequestSend::OnRepresentation => {
106                let control_handle = stream.control_handle();
107                let representation =
108                    connection.get_representation(self.attributes).await.map_err(|s| {
109                        control_handle.shutdown_with_epitaph(s);
110                        s
111                    })?;
112                control_handle
113                    .send_on_representation(representation)
114                    .map_err(|_| Status::PEER_CLOSED)?;
115            }
116            ObjectRequestSend::Nothing => {}
117        }
118        Ok(stream.cast_stream())
119    }
120
121    /// Converts to ServerEnd<T>.
122    pub fn into_server_end<T>(self) -> ServerEnd<T> {
123        ServerEnd::new(self.object_request)
124    }
125
126    /// Extracts the channel (without sending on_open).
127    pub fn into_channel(self) -> fidl::Channel {
128        self.object_request
129    }
130
131    /// Extracts the channel after sending on_open.
132    pub fn into_channel_after_sending_on_open(
133        self,
134        node_info: fio::NodeInfoDeprecated,
135    ) -> Result<fidl::Channel, Status> {
136        let stream = fio::NodeRequestStream::from_channel(fasync::Channel::from_channel(
137            self.object_request,
138        ));
139        send_on_open(&stream.control_handle(), node_info)?;
140        let (inner, _is_terminated) = stream.into_inner();
141        // It's safe to unwrap here because inner is clearly the only Arc reference left.
142        Ok(Arc::try_unwrap(inner).unwrap().into_channel().into())
143    }
144
145    /// Terminates the object request with the given status.
146    pub fn shutdown(self, status: Status) {
147        if self.object_request.as_handle_ref().is_invalid() {
148            return;
149        }
150        if let ObjectRequestSend::OnOpen = self.what_to_send {
151            let (_, control_handle) = ServerEnd::<fio::NodeMarker>::new(self.object_request)
152                .into_stream_and_control_handle();
153            let _ = control_handle.send_on_open_(status.into_raw(), None);
154            control_handle.shutdown_with_epitaph(status);
155        } else {
156            let _ = self.object_request.close_with_epitaph(status);
157        }
158    }
159
160    /// Calls `f` and sends an error on the object request channel upon failure.
161    pub fn handle<T>(
162        mut self,
163        f: impl FnOnce(ObjectRequestRef<'_>) -> Result<T, Status>,
164    ) -> Option<T> {
165        match f(&mut self) {
166            Ok(o) => Some(o),
167            Err(s) => {
168                self.shutdown(s);
169                None
170            }
171        }
172    }
173
174    /// Calls `f` and sends an error on the object request channel upon failure.
175    pub async fn handle_async(
176        mut self,
177        f: impl AsyncFnOnce(&mut ObjectRequest) -> Result<(), Status>,
178    ) {
179        if let Err(s) = f(&mut self).await {
180            self.shutdown(s);
181        }
182    }
183
184    /// Waits until the request has a request waiting in its channel.  Returns immediately if this
185    /// request requires sending an initial event such as OnOpen or OnRepresentation.  Returns
186    /// `true` if the channel is readable (rather than just closed).
187    pub async fn wait_till_ready(&self) -> bool {
188        if !matches!(self.what_to_send, ObjectRequestSend::Nothing) {
189            return true;
190        }
191        let signals = fasync::OnSignalsRef::new(
192            self.object_request.as_handle_ref(),
193            fidl::Signals::OBJECT_READABLE | fidl::Signals::CHANNEL_PEER_CLOSED,
194        )
195        .await
196        .unwrap();
197        signals.contains(fidl::Signals::OBJECT_READABLE)
198    }
199
200    /// Take the ObjectRequest.  The caller is responsible for sending errors.
201    pub fn take(&mut self) -> ObjectRequest {
202        assert!(!self.object_request.as_handle_ref().is_invalid());
203        Self {
204            object_request: std::mem::replace(
205                &mut self.object_request,
206                fidl::NullableHandle::invalid().into(),
207            ),
208            what_to_send: self.what_to_send,
209            attributes: self.attributes,
210            create_attributes: self.create_attributes.take(),
211            truncate: self.truncate,
212        }
213    }
214
215    /// Constructs a new connection to `node` and spawns an async `Task` that will handle requests
216    /// on the connection. `f` is a callback that constructs the connection but it will not be
217    /// called if the connection is supposed to be a node connection. This should be called from
218    /// within a [`ObjectRequest::handle_async`] callback.
219    pub async fn create_connection<C, N>(
220        &mut self,
221        scope: ExecutionScope,
222        node: Arc<N>,
223        protocols: impl ProtocolsExt,
224    ) -> Result<(), Status>
225    where
226        C: ConnectionCreator<N>,
227        N: Node,
228    {
229        assert!(!self.object_request.as_handle_ref().is_invalid());
230        if protocols.is_node() {
231            node::Connection::create(scope, node, protocols, self).await
232        } else {
233            C::create(scope, node, protocols, self).await
234        }
235    }
236
237    /// Constructs a new connection to `node` and spawns an async `Task` that will handle requests
238    /// on the connection. `f` is a callback that constructs the connection but it will not be
239    /// called if the connection is supposed to be a node connection. This should be called from
240    /// within a [`ObjectRequest::handle`] callback.
241    ///
242    /// This method synchronously calls async code and may require spawning an extra Task if the
243    /// async code does something asynchronous. `create_connection` should be preferred if the
244    /// caller is already in an async context.
245    pub fn create_connection_sync<C, N>(
246        self,
247        scope: ExecutionScope,
248        node: Arc<N>,
249        protocols: impl ProtocolsExt,
250    ) where
251        C: ConnectionCreator<N>,
252        N: Node,
253    {
254        assert!(!self.object_request.as_handle_ref().is_invalid());
255        if protocols.is_node() {
256            self.create_connection_sync_or_spawn::<node::Connection<N>, N>(scope, node, protocols);
257        } else {
258            self.create_connection_sync_or_spawn::<C, N>(scope, node, protocols);
259        }
260    }
261
262    fn create_connection_sync_or_spawn<C, N>(
263        self,
264        scope: ExecutionScope,
265        node: Arc<N>,
266        protocols: impl ProtocolsExt,
267    ) where
268        C: ConnectionCreator<N>,
269        N: Node,
270    {
271        let scope2 = scope.clone();
272        let fut = self.handle_async(async |object_request| {
273            C::create(scope2, node, protocols, object_request).await
274        });
275        run_synchronous_future_or_spawn(scope, fut);
276    }
277}
278
279pub type ObjectRequestRef<'a> = &'a mut ObjectRequest;
280
/// Which initial event, if any, is sent on a request's channel before requests are served.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// NOTE(review): the reason for this allow is not visible here. Presumably some variant is
// unused in certain build configurations; confirm and document it, or remove the allow.
#[allow(dead_code)]
pub(crate) enum ObjectRequestSend {
    /// Send an io1 `OnOpen` event. On success it carries a `fio::NodeInfoDeprecated`; on
    /// shutdown it carries the error status.
    OnOpen,
    /// Send an io2 `OnRepresentation` event carrying a `fio::Representation`.
    OnRepresentation,
    /// Send nothing before serving requests.
    Nothing,
}
288
289/// Trait to get either fio::Representation or fio::NodeInfoDeprecated.  Connection types
290/// should implement this.
291pub trait Representation {
292    /// The protocol used for the connection.
293    type Protocol: ProtocolMarker;
294
295    /// Returns io2's Representation for the object.
296    fn get_representation(
297        &self,
298        requested_attributes: fio::NodeAttributesQuery,
299    ) -> impl Future<Output = Result<fio::Representation, Status>> + Send;
300
301    /// Returns io1's NodeInfoDeprecated.
302    fn node_info(&self) -> impl Future<Output = Result<fio::NodeInfoDeprecated, Status>> + Send;
303}
304
305/// Convenience trait for converting [`fio::Flags`] and [`fio::OpenFlags`] into ObjectRequest.
306///
307/// If [`fio::Options`] need to be specified, use [`ObjectRequest::new`].
308pub trait ToObjectRequest: ProtocolsExt {
309    fn to_object_request(&self, object_request: impl Into<fidl::NullableHandle>) -> ObjectRequest;
310}
311
312impl ToObjectRequest for fio::OpenFlags {
313    fn to_object_request(&self, object_request: impl Into<fidl::NullableHandle>) -> ObjectRequest {
314        ObjectRequest::new_deprecated(
315            object_request.into().into(),
316            if self.contains(fio::OpenFlags::DESCRIBE) {
317                ObjectRequestSend::OnOpen
318            } else {
319                ObjectRequestSend::Nothing
320            },
321            fio::NodeAttributesQuery::empty(),
322            None,
323            self.is_truncate(),
324        )
325    }
326}
327
328impl ToObjectRequest for fio::Flags {
329    fn to_object_request(&self, object_request: impl Into<fidl::NullableHandle>) -> ObjectRequest {
330        ObjectRequest::new(*self, &Default::default(), object_request.into().into())
331    }
332}
333
334fn send_on_open(
335    control_handle: &fio::NodeControlHandle,
336    node_info: fio::NodeInfoDeprecated,
337) -> Result<(), Status> {
338    control_handle
339        .send_on_open_(Status::OK.into_raw(), Some(node_info))
340        .map_err(|_| Status::PEER_CLOSED)
341}
342
343/// Trait for constructing connections to nodes.
344pub trait ConnectionCreator<T: Node> {
345    /// Creates a new connection to `node` and spawns a new `Task` to run the connection.
346    fn create<'a>(
347        scope: ExecutionScope,
348        node: Arc<T>,
349        protocols: impl ProtocolsExt,
350        object_request: ObjectRequestRef<'a>,
351    ) -> impl Future<Output = Result<(), Status>> + Send + 'a;
352}
353
354/// Synchronously polls `future` with the expectation that it won't return Pending. If the future
355/// does return Pending then this function will spawn a Task to run the future.
356pub(crate) fn run_synchronous_future_or_spawn(
357    scope: ExecutionScope,
358    future: impl Future<Output = ()> + Send + 'static,
359) {
360    let mut task = scope.new_task(future);
361    let mut cx = std::task::Context::from_waker(std::task::Waker::noop());
362
363    match task.poll_unpin(&mut cx) {
364        std::task::Poll::Pending => task.spawn(),
365        std::task::Poll::Ready(()) => {}
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution_scope::yield_to_executor;
    use std::future::ready;

    /// Runs an already-ready future, then waits for the scope to drain.
    #[fuchsia::test]
    async fn test_run_synchronous_future_or_spawn_with_sync_future() {
        let scope = ExecutionScope::new();
        run_synchronous_future_or_spawn(scope.clone(), ready(()));
        scope.wait().await;
    }

    /// Runs a future that yields to the executor, then waits for the scope to drain.
    ///
    /// Presumably the yield makes the first poll return `Pending`, which exercises the spawn
    /// path.
    #[fuchsia::test]
    async fn test_run_synchronous_future_or_spawn_with_async_future() {
        let scope = ExecutionScope::new();
        run_synchronous_future_or_spawn(scope.clone(), yield_to_executor());
        scope.wait().await;
    }
}