starnix_core/fs/fuchsia/
remote_unix_domain_socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::fs::fuchsia::{OpenFlags, new_remote_file};
6use crate::task::{
7    CurrentTask, EventHandler, FullCredentials, SignalHandler, SignalHandlerInner, WaitCanceler,
8    Waiter,
9};
10use crate::vfs::buffers::{InputBuffer, OutputBuffer};
11use crate::vfs::socket::{
12    SockOptValue, Socket, SocketAddress, SocketDomain, SocketHandle, SocketMessageFlags, SocketOps,
13    SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
14};
15use crate::vfs::{AncillaryData, FileHandle, MessageReadInfo, UnixControlData};
16use fidl::endpoints::SynchronousProxy;
17use linux_uapi::{SO_LINGER, SOL_SOCKET};
18use starnix_sync::{FileOpsCore, Locked};
19use starnix_uapi::errors::Errno;
20use starnix_uapi::vfs::FdEvents;
21use starnix_uapi::{errno, error, from_status_like_fdio, uapi, ucred};
22use zerocopy::IntoBytes;
23use zx::AsHandleRef;
24use {fidl_fuchsia_io as fio, fidl_fuchsia_starnix_binder as fbinder};
/// Zircon signal observed on the socket's event pair to report readability; it carries the
/// same bits as `fio::FileSignal::READABLE`.
static READABLE_SIGNAL: zx::Signals =
    zx::Signals::from_bits_retain(fio::FileSignal::READABLE.bits());
/// Zircon signal observed on the socket's event pair to report writability; it carries the
/// same bits as `fio::FileSignal::WRITABLE`.
static WRITABLE_SIGNAL: zx::Signals =
    zx::Signals::from_bits_retain(fio::FileSignal::WRITABLE.bits());
29
/// Socket implementation whose operations are forwarded over FIDL to a remote
/// `fbinder::UnixDomainSocket` server.
pub struct RemoteUnixDomainSocket {
    /// Synchronous proxy through which every request is sent to the remote end.
    client: fbinder::UnixDomainSocketSynchronousProxy,
    /// Event pair returned by the remote end's `get_event`; readiness signals are read from it.
    event: zx::EventPair,
    /// Credentials temporarily assumed while file handles are transferred in either direction.
    remote_creds: FullCredentials,
}
35
36impl RemoteUnixDomainSocket {
37    pub fn new(channel: zx::Channel, remote_creds: FullCredentials) -> Result<Self, Errno> {
38        let client = fbinder::UnixDomainSocketSynchronousProxy::from_channel(channel);
39        let response = client
40            .get_event(
41                &fbinder::UnixDomainSocketGetEventRequest::default(),
42                zx::MonotonicInstant::INFINITE,
43            )
44            .map_err(|_| errno!(ECONNREFUSED))?
45            .map_err(|e: i32| from_status_like_fdio!(zx::Status::from_raw(e)))?;
46        let event = response.event.ok_or_else(|| errno!(ECONNREFUSED))?;
47        Ok(Self { client, event, remote_creds })
48    }
49
50    fn get_signals_from_events(events: FdEvents) -> zx::Signals {
51        let mut signals = zx::Signals::NONE;
52        if events.contains(FdEvents::POLLIN) {
53            signals |= READABLE_SIGNAL;
54        }
55        if events.contains(FdEvents::POLLOUT) {
56            signals |= WRITABLE_SIGNAL;
57        }
58        signals
59    }
60
61    fn get_events_from_signals(signals: zx::Signals) -> FdEvents {
62        let mut events = FdEvents::empty();
63        if signals.contains(READABLE_SIGNAL) {
64            events |= FdEvents::POLLIN;
65        }
66        if signals.contains(WRITABLE_SIGNAL) {
67            events |= FdEvents::POLLOUT;
68        }
69        events
70    }
71
    /// Perform an action using the credentials of the remote task.
    ///
    /// Runs `f` through `current_task.override_creds`, with the temporary credentials replaced
    /// by a clone of `self.remote_creds`, and returns whatever that call returns.
    fn with_remote_creds<F, R>(&self, current_task: &CurrentTask, f: F) -> Result<R, Errno>
    where
        F: FnOnce() -> Result<R, Errno>,
    {
        current_task.override_creds(|temp_creds| *temp_creds = self.remote_creds.clone(), f)
    }
79}
80
81impl SocketOps for RemoteUnixDomainSocket {
82    fn get_socket_info(&self) -> Result<(SocketDomain, SocketType, SocketProtocol), Errno> {
83        Ok((SocketDomain::Unix, SocketType::Datagram, SocketProtocol::from_raw(0)))
84    }
85
86    fn connect(
87        &self,
88        _locked: &mut Locked<FileOpsCore>,
89        _socket: &SocketHandle,
90        _current_task: &CurrentTask,
91        _peer: SocketPeer,
92    ) -> Result<(), Errno> {
93        error!(EISCONN)
94    }
95
96    fn listen(
97        &self,
98        _locked: &mut Locked<FileOpsCore>,
99        _socket: &Socket,
100        _backlog: i32,
101        _credentials: ucred,
102    ) -> Result<(), Errno> {
103        error!(EOPNOTSUPP)
104    }
105
106    fn accept(
107        &self,
108        _locked: &mut Locked<FileOpsCore>,
109        _socket: &Socket,
110        _current_task: &CurrentTask,
111    ) -> Result<SocketHandle, Errno> {
112        error!(EOPNOTSUPP)
113    }
114
115    fn bind(
116        &self,
117        _locked: &mut Locked<FileOpsCore>,
118        _socket: &Socket,
119        _current_task: &CurrentTask,
120        _socket_address: SocketAddress,
121    ) -> Result<(), Errno> {
122        error!(EOPNOTSUPP)
123    }
124
125    fn read(
126        &self,
127        locked: &mut Locked<FileOpsCore>,
128        _socket: &Socket,
129        current_task: &CurrentTask,
130        data: &mut dyn OutputBuffer,
131        flags: SocketMessageFlags,
132    ) -> Result<MessageReadInfo, Errno> {
133        if self.client.is_closed().map_err(|_| errno!(ECONNREFUSED))? {
134            return error!(ECONNREFUSED);
135        }
136        let mut read_flags = fbinder::ReadFlags::empty();
137        if flags.contains(SocketMessageFlags::PEEK) {
138            read_flags |= fbinder::ReadFlags::PEEK;
139        }
140
141        let response = self
142            .client
143            .read(
144                &fbinder::UnixDomainSocketReadRequest {
145                    count: Some(data.available() as u64),
146                    flags: Some(read_flags),
147                    ..Default::default()
148                },
149                zx::MonotonicInstant::INFINITE,
150            )
151            .map_err(|_| errno!(ECONNREFUSED))?
152            .map_err(|e: i32| {
153                let status = zx::Status::from_raw(e);
154                if status == zx::Status::PEER_CLOSED {
155                    errno!(ECONNRESET)
156                } else {
157                    from_status_like_fdio!(status)
158                }
159            })?;
160
161        let written =
162            if let Some(received_data) = response.data { data.write(&received_data)? } else { 0 };
163
164        let mut file_handles: Vec<FileHandle> = vec![];
165        if let Some(handles) = response.handles {
166            // Use the remote task's credentials to create the remote_file object. This ensures
167            // that the SID associated to the fd is set to the correct value.
168            self.with_remote_creds(current_task, || {
169                for handle in handles {
170                    file_handles.push(new_remote_file(
171                        locked,
172                        current_task,
173                        handle,
174                        OpenFlags::RDWR,
175                    )?);
176                }
177                Ok(())
178            })?;
179        }
180        let ancillary_data = vec![AncillaryData::Unix(UnixControlData::Rights(file_handles))];
181
182        let message_length = response.data_original_length.unwrap_or(written as u64) as usize;
183
184        Ok(MessageReadInfo { bytes_read: written, message_length, address: None, ancillary_data })
185    }
186
187    fn write(
188        &self,
189        _locked: &mut Locked<FileOpsCore>,
190        _socket: &Socket,
191        current_task: &CurrentTask,
192        data: &mut dyn InputBuffer,
193        _dest_address: &mut Option<SocketAddress>,
194        ancillary_data: &mut Vec<AncillaryData>,
195    ) -> Result<usize, Errno> {
196        if self.client.is_closed().map_err(|_| errno!(ECONNREFUSED))? {
197            return error!(ECONNREFUSED);
198        }
199
200        let mut handles: Vec<zx::NullableHandle> = vec![];
201        for data in ancillary_data {
202            match data {
203                AncillaryData::Unix(UnixControlData::Rights(file_handles)) => {
204                    // Access the served files with the credentials of the remote end.
205                    self.with_remote_creds(current_task, || {
206                        for file_handle in file_handles {
207                            let Some(handle) = file_handle.to_handle(current_task)? else {
208                                return error!(EINVAL);
209                            };
210                            handles.push(handle);
211                        }
212                        Ok(())
213                    })?;
214                }
215                _ => return error!(EINVAL),
216            }
217        }
218
219        let bytes = data.read_all()?;
220
221        let response = self
222            .client
223            .write(
224                fbinder::UnixDomainSocketWriteRequest {
225                    data: Some(bytes),
226                    handles: Some(handles),
227                    ..Default::default()
228                },
229                zx::MonotonicInstant::INFINITE,
230            )
231            .map_err(|_| errno!(ECONNREFUSED))?
232            .map_err(|e: i32| from_status_like_fdio!(zx::Status::from_raw(e)))?;
233
234        let written = response.actual_count.unwrap_or(0);
235        Ok(written as usize)
236    }
237
238    fn wait_async(
239        &self,
240        _locked: &mut Locked<FileOpsCore>,
241        _socket: &Socket,
242        _current_task: &CurrentTask,
243        waiter: &Waiter,
244        events: FdEvents,
245        handler: EventHandler,
246    ) -> WaitCanceler {
247        let signal_handler = SignalHandler {
248            inner: SignalHandlerInner::ZxHandle(Self::get_events_from_signals),
249            event_handler: handler,
250            err_code: None,
251        };
252        let canceler = waiter
253            .wake_on_zircon_signals(
254                &self.event,
255                Self::get_signals_from_events(events),
256                signal_handler,
257            )
258            .unwrap();
259        WaitCanceler::new_port(canceler)
260    }
261
262    fn query_events(
263        &self,
264        _locked: &mut Locked<FileOpsCore>,
265        _socket: &Socket,
266        _current_task: &CurrentTask,
267    ) -> Result<FdEvents, Errno> {
268        let signals = self
269            .event
270            .as_handle_ref()
271            .wait(zx::Signals::NONE, zx::MonotonicInstant::INFINITE_PAST)
272            .map_err(|e| from_status_like_fdio!(e))?;
273        Ok(Self::get_events_from_signals(signals))
274    }
275
    /// Accepts any shutdown request without doing anything: `_how` is ignored, nothing is sent
    /// to the remote end, and the call always succeeds.
    fn shutdown(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _how: SocketShutdownFlags,
    ) -> Result<(), Errno> {
        Ok(())
    }
284
    /// Asks the remote end to close the connection, blocking without timeout for the reply.
    fn close(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _current_task: &CurrentTask,
        _socket: &Socket,
    ) {
        // Best effort: the outcome of the call is deliberately discarded.
        let _ = self.client.close(zx::MonotonicInstant::INFINITE);
    }
293
294    fn getsockname(
295        &self,
296        _locked: &mut Locked<FileOpsCore>,
297        _socket: &Socket,
298    ) -> Result<SocketAddress, Errno> {
299        Ok(SocketAddress::default_for_domain(SocketDomain::Unix))
300    }
301
302    fn getpeername(
303        &self,
304        locked: &mut Locked<FileOpsCore>,
305        socket: &Socket,
306    ) -> Result<SocketAddress, Errno> {
307        self.getsockname(locked, socket)
308    }
309
310    fn setsockopt(
311        &self,
312        _locked: &mut Locked<FileOpsCore>,
313        _socket: &Socket,
314        _current_task: &CurrentTask,
315        _level: u32,
316        _optname: u32,
317        _optval: SockOptValue,
318    ) -> Result<(), Errno> {
319        error!(EOPNOTSUPP)
320    }
321
322    fn getsockopt(
323        &self,
324        _locked: &mut Locked<FileOpsCore>,
325        _socket: &Socket,
326        _current_task: &CurrentTask,
327        level: u32,
328        optname: u32,
329        _optlen: u32,
330    ) -> Result<Vec<u8>, Errno> {
331        if level != SOL_SOCKET {
332            return error!(EINVAL);
333        }
334        let data = match optname {
335            SO_LINGER => uapi::linger::default().as_bytes().to_vec(),
336            _ => return error!(EINVAL),
337        };
338
339        Ok(data)
340    }
341
342    fn to_handle(
343        &self,
344        _socket: &Socket,
345        _current_task: &CurrentTask,
346    ) -> Result<Option<zx::NullableHandle>, Errno> {
347        let (proxy, server) = zx::Channel::create();
348        self.client.clone(server.into()).map_err(|_| errno!(ECONNREFUSED))?;
349        Ok(Some(zx::NullableHandle::from(proxy).into()))
350    }
351}
352
353#[cfg(test)]
354mod tests {
355    use super::*;
356    use crate::testing::spawn_kernel_and_run;
357    use crate::vfs::socket::SocketFile;
358    use crate::vfs::{VecInputBuffer, VecOutputBuffer};
359    use fidl::endpoints::{DiscoverableProtocolMarker as _, RequestStream};
360    use futures::StreamExt;
361    use starnix_sync::Mutex;
362    use std::sync::Arc;
363    use zx::HandleBased;
364    use {fidl_fuchsia_unknown as funknown, fuchsia_async as fasync};
365
    /// One message queued in the fake server: a byte payload plus the handles sent with it.
    #[derive(Debug)]
    struct Data {
        /// Payload bytes of the message.
        bytes: Vec<u8>,
        /// Handles that accompanied the payload.
        handles: Vec<zx::NullableHandle>,
    }
371
372    impl Data {
373        fn try_clone(&mut self) -> Result<Self, zx::Status> {
374            let mut new_handles = vec![];
375            for handle in std::mem::take(&mut self.handles) {
376                let (new_handle, old_handle) = {
377                    let handle_type = handle.basic_info()?.object_type;
378                    match handle_type {
379                        zx::ObjectType::CHANNEL => {
380                            let channel = zx::Channel::from(handle);
381                            let client = funknown::CloneableSynchronousProxy::new(channel);
382                            let (proxy, server) = zx::Channel::create();
383                            let new_handle = client
384                                .clone(server.into())
385                                .map(|_| proxy.into())
386                                .map_err(|_| zx::Status::NOT_SUPPORTED);
387                            (new_handle, client.into_channel().into_handle())
388                        }
389                        _ => {
390                            let new_handle = handle.duplicate_handle(zx::Rights::SAME_RIGHTS);
391                            (new_handle, handle)
392                        }
393                    }
394                };
395                self.handles.push(old_handle);
396                new_handles.push(new_handle);
397            }
398            let new_handles =
399                new_handles.into_iter().collect::<Result<Vec<zx::NullableHandle>, zx::Status>>()?;
400            Ok(Self { bytes: self.bytes.clone(), handles: new_handles })
401        }
402    }
403
    /// Mutable state of the fake server, guarded by the mutex in `UnixDomainSocketImpl`.
    #[derive(Debug)]
    struct UnixDomainSocketImplState {
        /// Local end of the event pair; stored here but never read.
        _local_event: zx::EventPair,
        /// End of the event pair that is duplicated to clients and signaled on reads/writes.
        remote_event: zx::EventPair,
        /// Messages written so far and not yet consumed, oldest first.
        buffer: Vec<Data>,
    }
410
411    impl Default for UnixDomainSocketImplState {
412        fn default() -> Self {
413            let (_local_event, remote_event) = zx::EventPair::create();
414            Self { _local_event, remote_event, buffer: vec![] }
415        }
416    }
417
    /// In-process fake of the `UnixDomainSocket` protocol used by the tests below.
    #[derive(Debug, Default)]
    struct UnixDomainSocketImpl {
        /// Queue and event pair shared by all served connections.
        state: Mutex<UnixDomainSocketImplState>,
        /// When set, every read fails with `PEER_CLOSED`.
        close_on_read: bool,
    }
423
424    impl UnixDomainSocketImpl {
425        fn read(
426            &self,
427            payload: fbinder::UnixDomainSocketReadRequest,
428        ) -> Result<fbinder::UnixDomainSocketReadResponse, zx::Status> {
429            if self.close_on_read {
430                return Err(zx::Status::PEER_CLOSED);
431            }
432            let Some(count) = payload.count else {
433                return Err(zx::Status::INVALID_ARGS);
434            };
435            let Some(flags) = payload.flags else {
436                return Err(zx::Status::INVALID_ARGS);
437            };
438            let mut state = self.state.lock();
439            if state.buffer.is_empty() {
440                return Err(zx::Status::SHOULD_WAIT);
441            }
442            let mut data = if flags.contains(fbinder::ReadFlags::PEEK) {
443                state.buffer[0].try_clone()?
444            } else {
445                state.buffer.remove(0)
446            };
447
448            if state.buffer.is_empty() {
449                state.remote_event.as_handle_ref().signal(READABLE_SIGNAL, WRITABLE_SIGNAL)?;
450            }
451
452            let actual_count = data.bytes.len() as u64;
453            data.bytes.truncate(count as usize);
454
455            Ok(fbinder::UnixDomainSocketReadResponse {
456                data: Some(data.bytes),
457                data_original_length: Some(actual_count),
458                handles: Some(data.handles),
459                ..Default::default()
460            })
461        }
462
463        fn write(
464            &self,
465            payload: fbinder::UnixDomainSocketWriteRequest,
466        ) -> Result<fbinder::UnixDomainSocketWriteResponse, zx::Status> {
467            let Some(bytes) = payload.data else {
468                return Err(zx::Status::INVALID_ARGS);
469            };
470            let actual_count = bytes.len() as u64;
471            let Some(handles) = payload.handles else {
472                return Err(zx::Status::INVALID_ARGS);
473            };
474            let mut state = self.state.lock();
475            state.buffer.push(Data { bytes, handles });
476            state
477                .remote_event
478                .as_handle_ref()
479                .signal(zx::Signals::NONE, READABLE_SIGNAL | WRITABLE_SIGNAL)?;
480            Ok(fbinder::UnixDomainSocketWriteResponse {
481                actual_count: Some(actual_count),
482                ..Default::default()
483            })
484        }
485
        /// Serves `UnixDomainSocket` requests arriving on `channel` until the request stream
        /// ends, handling incoming requests concurrently with no limit.
        ///
        /// A `Clone` request is served on the new channel by calling `serve` again, against
        /// the same shared state.
        async fn serve(self: &Arc<Self>, channel: zx::Channel) {
            let stream = fbinder::UnixDomainSocketRequestStream::from_channel(
                fasync::Channel::from_channel(channel),
            );
            stream
                .for_each_concurrent(None, |message| async {
                    match message {
                        // Hand out a duplicate of the event clients wait on.
                        Ok(fbinder::UnixDomainSocketRequest::GetEvent { responder, .. }) => {
                            let state = self.state.lock();
                            let event = state
                                .remote_event
                                .duplicate_handle(zx::Rights::SAME_RIGHTS)
                                .expect("duplicate event");
                            responder
                                .send(Ok(fbinder::UnixDomainSocketGetEventResponse {
                                    event: Some(event),
                                    ..Default::default()
                                }))
                                .expect("respond");
                        }
                        // Reads and writes reply with the raw zx status on failure.
                        Ok(fbinder::UnixDomainSocketRequest::Read {
                            payload, responder, ..
                        }) => {
                            assert!(
                                responder
                                    .send(self.read(payload).map_err(|e| e.into_raw()))
                                    .is_ok()
                            );
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Write {
                            payload, responder, ..
                        }) => {
                            assert!(
                                responder
                                    .send(self.write(payload).as_ref().map_err(|e| e.into_raw()))
                                    .is_ok()
                            );
                        }
                        // Identify the protocol by its discoverable name.
                        Ok(fbinder::UnixDomainSocketRequest::Query { responder }) => {
                            assert!(
                                responder
                                    .send(fbinder::UnixDomainSocketMarker::PROTOCOL_NAME.as_bytes())
                                    .is_ok()
                            );
                        }
                        // Serve the cloned connection to completion within this handler.
                        Ok(fbinder::UnixDomainSocketRequest::Clone { request, .. }) => {
                            self.serve(request.into()).await;
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Close { responder }) => {
                            assert!(responder.send(Ok(())).is_ok());
                        }
                        // Anything else (including stream errors) is ignored.
                        _ => {
                            return;
                        }
                    }
                })
                .await;
        }
544    }
545
546    #[::fuchsia::test]
547    async fn test_remote_uds() {
548        let (client, server) = zx::Channel::create();
549        let handle = std::thread::spawn(|| {
550            let mut executor = fasync::LocalExecutor::default();
551            executor.run_singlethreaded(async move {
552                let uds_impl = UnixDomainSocketImpl::default();
553                Arc::new(uds_impl).serve(server).await;
554            });
555        });
556        spawn_kernel_and_run(async move |locked, current_task| {
557            let original_file =
558                new_remote_file(locked, current_task, client.into(), OpenFlags::RDWR)
559                    .expect("new_remote_file");
560            assert!(original_file.node().is_sock());
561            let file = new_remote_file(
562                locked,
563                current_task,
564                original_file.to_handle(current_task).expect("to_handle").expect("has_handle"),
565                OpenFlags::RDWR,
566            )
567            .expect("new_remote_file");
568            let ancillary_data =
569                vec![AncillaryData::Unix(UnixControlData::Rights(vec![original_file]))];
570            let socket_ops = file.downcast_file::<SocketFile>().unwrap();
571            let data = "HelloWorld";
572            let mut input_buffer = VecInputBuffer::new(data.as_bytes());
573            assert_eq!(
574                socket_ops.sendmsg(
575                    locked,
576                    current_task,
577                    &file,
578                    &mut input_buffer,
579                    None,
580                    ancillary_data,
581                    SocketMessageFlags::empty()
582                ),
583                Ok(data.len())
584            );
585
586            let flags = SocketMessageFlags::CTRUNC
587                | SocketMessageFlags::TRUNC
588                | SocketMessageFlags::NOSIGNAL
589                | SocketMessageFlags::CMSG_CLOEXEC;
590
591            let mut buffer = VecOutputBuffer::new(1024);
592            let info = socket_ops
593                .recvmsg(
594                    locked,
595                    &current_task,
596                    &file,
597                    &mut buffer,
598                    flags | SocketMessageFlags::PEEK,
599                    None,
600                )
601                .expect("recvmsg");
602
603            assert_eq!(info.ancillary_data.len(), 1);
604            assert_eq!(info.message_length, data.len());
605
606            let mut buffer = VecOutputBuffer::new(1024);
607            let info = socket_ops
608                .recvmsg(locked, &current_task, &file, &mut buffer, flags, None)
609                .expect("recvmsg");
610
611            assert_eq!(info.ancillary_data.len(), 1);
612            assert_eq!(info.message_length, data.len());
613
614            let mut buffer = VecOutputBuffer::new(1024);
615            let err = socket_ops
616                .recvmsg(
617                    locked,
618                    &current_task,
619                    &file,
620                    &mut buffer,
621                    flags | SocketMessageFlags::DONTWAIT,
622                    None,
623                )
624                .unwrap_err();
625            assert_eq!(err, errno!(EAGAIN));
626        })
627        .await;
628        handle.join().expect("join");
629    }
630
631    #[::fuchsia::test]
632    async fn test_remote_uds_peer_closed() {
633        let (client, server) = zx::Channel::create();
634        let handle = std::thread::spawn(move || {
635            let mut executor = fasync::LocalExecutor::default();
636            executor.run_singlethreaded(async move {
637                let uds_impl = UnixDomainSocketImpl { close_on_read: true, ..Default::default() };
638                Arc::new(uds_impl).serve(server).await;
639            });
640        });
641        spawn_kernel_and_run(async move |locked, current_task| {
642            let file = new_remote_file(locked, current_task, client.into(), OpenFlags::RDWR)
643                .expect("new_remote_file");
644            let socket_ops = file.downcast_file::<SocketFile>().unwrap();
645
646            let mut buffer = VecOutputBuffer::new(1024);
647            let err = socket_ops
648                .recvmsg(
649                    locked,
650                    &current_task,
651                    &file,
652                    &mut buffer,
653                    SocketMessageFlags::empty(),
654                    None,
655                )
656                .unwrap_err();
657            assert_eq!(err, errno!(ECONNRESET));
658        })
659        .await;
660        handle.join().expect("join");
661    }
662}