starnix_core/fs/fuchsia/remote_unix_domain_socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::fs::fuchsia::{OpenFlags, new_remote_file};
6use crate::task::{
7    CurrentTask, EventHandler, FullCredentials, SignalHandler, SignalHandlerInner, WaitCanceler,
8    Waiter,
9};
10use crate::vfs::buffers::{InputBuffer, OutputBuffer};
11use crate::vfs::socket::{
12    SockOptValue, Socket, SocketAddress, SocketDomain, SocketHandle, SocketMessageFlags, SocketOps,
13    SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
14};
15use crate::vfs::{AncillaryData, FileHandle, MessageReadInfo, UnixControlData};
16use fidl::endpoints::SynchronousProxy;
17use linux_uapi::{SO_LINGER, SOL_SOCKET};
18use starnix_sync::{FileOpsCore, Locked};
19use starnix_uapi::errors::Errno;
20use starnix_uapi::vfs::FdEvents;
21use starnix_uapi::{errno, error, from_status_like_fdio, uapi, ucred};
22use zerocopy::IntoBytes;
23use {fidl_fuchsia_io as fio, fidl_fuchsia_starnix_binder as fbinder};
24static READABLE_SIGNAL: zx::Signals =
25    zx::Signals::from_bits_retain(fio::FileSignal::READABLE.bits());
26static WRITABLE_SIGNAL: zx::Signals =
27    zx::Signals::from_bits_retain(fio::FileSignal::WRITABLE.bits());
28
/// A Unix domain socket implemented by a remote server and reached through a synchronous
/// `fuchsia.starnix.binder/UnixDomainSocket` connection.
pub struct RemoteUnixDomainSocket {
    /// Synchronous FIDL connection to the remote socket implementation.
    client: fbinder::UnixDomainSocketSynchronousProxy,
    /// Event pair returned by `GetEvent`; its signals are translated into poll events.
    event: zx::EventPair,
    /// Credentials of the remote task, applied while creating or exporting files on its behalf.
    remote_creds: FullCredentials,
}

impl RemoteUnixDomainSocket {
    /// Wraps `channel` as a connection to a remote socket and fetches its readiness event.
    ///
    /// # Errors
    ///
    /// Returns `ECONNREFUSED` when the `GetEvent` call fails at the transport level or the reply
    /// carries no event, and the fdio translation of the status when the server reports an
    /// error.
    pub fn new(channel: zx::Channel, remote_creds: FullCredentials) -> Result<Self, Errno> {
        let client = fbinder::UnixDomainSocketSynchronousProxy::from_channel(channel);
        let request = fbinder::UnixDomainSocketGetEventRequest::default();
        let Ok(reply) = client.get_event(&request, zx::MonotonicInstant::INFINITE) else {
            return error!(ECONNREFUSED);
        };
        let response =
            reply.map_err(|raw: i32| from_status_like_fdio!(zx::Status::from_raw(raw)))?;
        match response.event {
            Some(event) => Ok(Self { client, event, remote_creds }),
            None => error!(ECONNREFUSED),
        }
    }

    /// Maps the requested poll events onto the Zircon signals to wait for on `event`.
    ///
    /// Only `POLLIN` and `POLLOUT` have a signal equivalent; other events are ignored.
    fn get_signals_from_events(events: FdEvents) -> zx::Signals {
        let readable =
            if events.contains(FdEvents::POLLIN) { READABLE_SIGNAL } else { zx::Signals::NONE };
        let writable =
            if events.contains(FdEvents::POLLOUT) { WRITABLE_SIGNAL } else { zx::Signals::NONE };
        readable | writable
    }

    /// Maps Zircon signals observed on `event` back onto poll events.
    fn get_events_from_signals(signals: zx::Signals) -> FdEvents {
        let pollin =
            if signals.contains(READABLE_SIGNAL) { FdEvents::POLLIN } else { FdEvents::empty() };
        let pollout =
            if signals.contains(WRITABLE_SIGNAL) { FdEvents::POLLOUT } else { FdEvents::empty() };
        pollin | pollout
    }

    /// Runs `action` with the remote task's credentials overriding those of `current_task`.
    fn with_remote_creds<F, R>(&self, current_task: &CurrentTask, action: F) -> Result<R, Errno>
    where
        F: FnOnce() -> Result<R, Errno>,
    {
        let creds = self.remote_creds.clone();
        current_task.override_creds(creds, action)
    }
}
79
impl SocketOps for RemoteUnixDomainSocket {
    /// Reports the socket as an `AF_UNIX` datagram socket with protocol 0.
    fn get_socket_info(&self) -> Result<(SocketDomain, SocketType, SocketProtocol), Errno> {
        Ok((SocketDomain::Unix, SocketType::Datagram, SocketProtocol::from_raw(0)))
    }

    /// Always fails with `EISCONN`: the socket is already attached to its remote end.
    fn connect(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &SocketHandle,
        _current_task: &CurrentTask,
        _peer: SocketPeer,
    ) -> Result<(), Errno> {
        error!(EISCONN)
    }

    /// Not supported on a remote socket.
    fn listen(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _backlog: i32,
        _credentials: ucred,
    ) -> Result<(), Errno> {
        error!(EOPNOTSUPP)
    }

    /// Not supported on a remote socket.
    fn accept(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _current_task: &CurrentTask,
    ) -> Result<SocketHandle, Errno> {
        error!(EOPNOTSUPP)
    }

    /// Not supported on a remote socket.
    fn bind(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _current_task: &CurrentTask,
        _socket_address: SocketAddress,
    ) -> Result<(), Errno> {
        error!(EOPNOTSUPP)
    }

    /// Reads one message from the remote socket.
    ///
    /// Asks the remote for at most `data.available()` bytes (forwarding `MSG_PEEK`) and copies
    /// the returned bytes into `data`. Any received handles are turned into remote files and
    /// returned as a single `SCM_RIGHTS` ancillary entry. The reported `message_length` is the
    /// remote's original message length when provided, so callers can detect truncation.
    ///
    /// # Errors
    ///
    /// `ECONNREFUSED` if the channel is closed or the FIDL call fails; `ECONNRESET` if the
    /// remote reports `PEER_CLOSED`; otherwise the fdio translation of the remote status.
    fn read(
        &self,
        locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        current_task: &CurrentTask,
        data: &mut dyn OutputBuffer,
        flags: SocketMessageFlags,
    ) -> Result<MessageReadInfo, Errno> {
        if self.client.is_closed().map_err(|_| errno!(ECONNREFUSED))? {
            return error!(ECONNREFUSED);
        }
        let mut read_flags = fbinder::ReadFlags::empty();
        if flags.contains(SocketMessageFlags::PEEK) {
            read_flags |= fbinder::ReadFlags::PEEK;
        }

        let response = self
            .client
            .read(
                &fbinder::UnixDomainSocketReadRequest {
                    count: Some(data.available() as u64),
                    flags: Some(read_flags),
                    ..Default::default()
                },
                zx::MonotonicInstant::INFINITE,
            )
            .map_err(|_| errno!(ECONNREFUSED))?
            .map_err(|e: i32| {
                let status = zx::Status::from_raw(e);
                if status == zx::Status::PEER_CLOSED {
                    errno!(ECONNRESET)
                } else {
                    from_status_like_fdio!(status)
                }
            })?;

        let written =
            if let Some(received_data) = response.data { data.write(&received_data)? } else { 0 };

        let mut ancillary_data = Vec::new();
        let handles = response.handles.unwrap_or_default();
        // Linux never delivers an SCM_RIGHTS control message without descriptors, so only
        // attach one when the remote actually transferred handles.
        if !handles.is_empty() {
            // Use the remote task's credentials to create the remote_file objects. This ensures
            // that the SID associated to each fd is set to the correct value.
            let files = self.with_remote_creds(current_task, || {
                let mut files: Vec<FileHandle> = Vec::with_capacity(handles.len());
                for handle in handles {
                    files.push(new_remote_file(locked, current_task, handle, OpenFlags::RDWR)?);
                }
                Ok(files)
            })?;
            ancillary_data.push(AncillaryData::Unix(UnixControlData::Rights(files)));
        }

        let message_length = response.data_original_length.unwrap_or(written as u64) as usize;

        Ok(MessageReadInfo { bytes_read: written, message_length, address: None, ancillary_data })
    }

    /// Sends all of `data` to the remote socket, together with any `SCM_RIGHTS` files.
    ///
    /// Returns the byte count the remote reports as written (0 if it reports none).
    ///
    /// # Errors
    ///
    /// `ECONNREFUSED` if the channel is closed or the FIDL call fails; `EINVAL` for any
    /// ancillary data other than file rights, or for a file that cannot be exported as a
    /// handle; otherwise the fdio translation of the remote status.
    fn write(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        current_task: &CurrentTask,
        data: &mut dyn InputBuffer,
        _dest_address: &mut Option<SocketAddress>,
        ancillary_data: &mut Vec<AncillaryData>,
    ) -> Result<usize, Errno> {
        if self.client.is_closed().map_err(|_| errno!(ECONNREFUSED))? {
            return error!(ECONNREFUSED);
        }

        let mut handles: Vec<zx::NullableHandle> = vec![];
        for control in ancillary_data.iter() {
            // Only file rights can be forwarded to the remote end.
            let AncillaryData::Unix(UnixControlData::Rights(files)) = control else {
                return error!(EINVAL);
            };
            // Access the served files with the credentials of the remote end.
            self.with_remote_creds(current_task, || {
                for file in files {
                    let Some(handle) = file.to_handle(current_task)? else {
                        return error!(EINVAL);
                    };
                    handles.push(handle);
                }
                Ok(())
            })?;
        }

        let bytes = data.read_all()?;

        let response = self
            .client
            .write(
                fbinder::UnixDomainSocketWriteRequest {
                    data: Some(bytes),
                    handles: Some(handles),
                    ..Default::default()
                },
                zx::MonotonicInstant::INFINITE,
            )
            .map_err(|_| errno!(ECONNREFUSED))?
            .map_err(|e: i32| from_status_like_fdio!(zx::Status::from_raw(e)))?;

        let written = response.actual_count.unwrap_or(0);
        Ok(written as usize)
    }

    /// Registers `handler` to run when the remote event pair asserts a signal matching
    /// `events`.
    fn wait_async(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _current_task: &CurrentTask,
        waiter: &Waiter,
        events: FdEvents,
        handler: EventHandler,
    ) -> WaitCanceler {
        let signal_handler = SignalHandler {
            inner: SignalHandlerInner::ZxHandle(Self::get_events_from_signals),
            event_handler: handler,
            err_code: None,
        };
        // NOTE(review): panics if the wait cannot be registered; this trait method has no
        // error channel.
        let canceler = waiter
            .wake_on_zircon_signals(
                &self.event,
                Self::get_signals_from_events(events),
                signal_handler,
            )
            .unwrap();
        WaitCanceler::new_port(canceler)
    }

    /// Reports the poll events implied by the event pair's current signals, without blocking
    /// (the wait deadline is `INFINITE_PAST`).
    fn query_events(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _current_task: &CurrentTask,
    ) -> Result<FdEvents, Errno> {
        let signals = self
            .event
            .as_handle_ref()
            .wait_one(zx::Signals::NONE, zx::MonotonicInstant::INFINITE_PAST)
            .map_err(|e| from_status_like_fdio!(e))?;
        Ok(Self::get_events_from_signals(signals))
    }

    /// Accepted but not forwarded to the remote end.
    fn shutdown(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _how: SocketShutdownFlags,
    ) -> Result<(), Errno> {
        Ok(())
    }

    /// Best-effort close of the remote connection; any error is ignored.
    fn close(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _current_task: &CurrentTask,
        _socket: &Socket,
    ) {
        let _ = self.client.close(zx::MonotonicInstant::INFINITE);
    }

    /// Always reports the default `AF_UNIX` address; the remote is not queried.
    fn getsockname(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
    ) -> Result<SocketAddress, Errno> {
        Ok(SocketAddress::default_for_domain(SocketDomain::Unix))
    }

    /// Reports the same address as `getsockname`.
    fn getpeername(
        &self,
        locked: &mut Locked<FileOpsCore>,
        socket: &Socket,
    ) -> Result<SocketAddress, Errno> {
        self.getsockname(locked, socket)
    }

    /// No socket options can be set on a remote socket.
    fn setsockopt(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _current_task: &CurrentTask,
        _level: u32,
        _optname: u32,
        _optval: SockOptValue,
    ) -> Result<(), Errno> {
        error!(EOPNOTSUPP)
    }

    /// Supports only `SOL_SOCKET`/`SO_LINGER`, answered with the default `linger` value; every
    /// other option fails with `EINVAL`.
    fn getsockopt(
        &self,
        _locked: &mut Locked<FileOpsCore>,
        _socket: &Socket,
        _current_task: &CurrentTask,
        level: u32,
        optname: u32,
        _optlen: u32,
    ) -> Result<Vec<u8>, Errno> {
        if level != SOL_SOCKET {
            return error!(EINVAL);
        }
        let data = match optname {
            SO_LINGER => uapi::linger::default().as_bytes().to_vec(),
            _ => return error!(EINVAL),
        };

        Ok(data)
    }

    /// Exports the socket as a new channel connected to the same remote via `Clone`.
    fn to_handle(
        &self,
        _socket: &Socket,
        _current_task: &CurrentTask,
    ) -> Result<Option<zx::NullableHandle>, Errno> {
        let (client_end, server_end) = zx::Channel::create();
        self.client.clone(server_end.into()).map_err(|_| errno!(ECONNREFUSED))?;
        Ok(Some(zx::NullableHandle::from(client_end)))
    }
}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::socket::SocketFile;
    use crate::vfs::{VecInputBuffer, VecOutputBuffer};
    use fidl::endpoints::{DiscoverableProtocolMarker as _, RequestStream};
    use futures::StreamExt;
    use starnix_sync::Mutex;
    use std::sync::Arc;
    use zx::HandleBased;
    use {fidl_fuchsia_unknown as funknown, fuchsia_async as fasync};

    /// A message buffered by the fake server: payload bytes plus transferred handles.
    #[derive(Debug)]
    struct Data {
        bytes: Vec<u8>,
        handles: Vec<zx::NullableHandle>,
    }

    impl Data {
        /// Produces an independent copy of this message, used to answer `PEEK` reads.
        ///
        /// Channel handles are duplicated through `fuchsia.unknown/Cloneable.Clone`; other
        /// handles through `duplicate_handle`. The original handles always remain in `self`,
        /// even when duplicating one of them fails.
        fn try_clone(&mut self) -> Result<Self, zx::Status> {
            let mut new_handles = Vec::with_capacity(self.handles.len());
            for handle in std::mem::take(&mut self.handles) {
                let (new_handle, old_handle) = match handle.basic_info() {
                    Ok(info) if info.object_type == zx::ObjectType::CHANNEL => {
                        let channel = zx::Channel::from(handle);
                        let client = funknown::CloneableSynchronousProxy::new(channel);
                        let (proxy, server) = zx::Channel::create();
                        let new_handle = client
                            .clone(server.into())
                            .map(|_| proxy.into())
                            .map_err(|_| zx::Status::NOT_SUPPORTED);
                        (new_handle, client.into_channel().into_handle())
                    }
                    Ok(_) => {
                        let new_handle = handle.duplicate_handle(zx::Rights::SAME_RIGHTS);
                        (new_handle, handle)
                    }
                    // Keep the original handle buffered; the error is reported below.
                    Err(status) => (Err(status), handle),
                };
                self.handles.push(old_handle);
                new_handles.push(new_handle);
            }
            let new_handles =
                new_handles.into_iter().collect::<Result<Vec<zx::NullableHandle>, zx::Status>>()?;
            Ok(Self { bytes: self.bytes.clone(), handles: new_handles })
        }
    }

    /// Mutable state of the fake server: the event pair exposed via `GetEvent` and the queue
    /// of pending messages.
    #[derive(Debug)]
    struct UnixDomainSocketImplState {
        _local_event: zx::EventPair,
        remote_event: zx::EventPair,
        buffer: Vec<Data>,
    }

    impl Default for UnixDomainSocketImplState {
        fn default() -> Self {
            let (_local_event, remote_event) = zx::EventPair::create();
            Self { _local_event, remote_event, buffer: vec![] }
        }
    }

    /// In-process fake of `fuchsia.starnix.binder/UnixDomainSocket`.
    #[derive(Debug, Default)]
    struct UnixDomainSocketImpl {
        state: Mutex<UnixDomainSocketImplState>,
        /// When set, every `Read` fails with `PEER_CLOSED`.
        close_on_read: bool,
    }

    impl UnixDomainSocketImpl {
        /// Pops (or, with `PEEK`, copies) the oldest message, truncated to `count` bytes.
        fn read(
            &self,
            payload: fbinder::UnixDomainSocketReadRequest,
        ) -> Result<fbinder::UnixDomainSocketReadResponse, zx::Status> {
            if self.close_on_read {
                return Err(zx::Status::PEER_CLOSED);
            }
            let Some(count) = payload.count else {
                return Err(zx::Status::INVALID_ARGS);
            };
            let Some(flags) = payload.flags else {
                return Err(zx::Status::INVALID_ARGS);
            };
            let mut state = self.state.lock();
            if state.buffer.is_empty() {
                return Err(zx::Status::SHOULD_WAIT);
            }
            let mut data = if flags.contains(fbinder::ReadFlags::PEEK) {
                state.buffer[0].try_clone()?
            } else {
                state.buffer.remove(0)
            };

            // Drop READABLE (keep WRITABLE) once the queue has drained.
            if state.buffer.is_empty() {
                state.remote_event.as_handle_ref().signal(READABLE_SIGNAL, WRITABLE_SIGNAL)?;
            }

            let actual_count = data.bytes.len() as u64;
            data.bytes.truncate(count as usize);

            Ok(fbinder::UnixDomainSocketReadResponse {
                data: Some(data.bytes),
                data_original_length: Some(actual_count),
                handles: Some(data.handles),
                ..Default::default()
            })
        }

        /// Queues a message and asserts READABLE | WRITABLE on the event.
        fn write(
            &self,
            payload: fbinder::UnixDomainSocketWriteRequest,
        ) -> Result<fbinder::UnixDomainSocketWriteResponse, zx::Status> {
            let Some(bytes) = payload.data else {
                return Err(zx::Status::INVALID_ARGS);
            };
            let actual_count = bytes.len() as u64;
            let Some(handles) = payload.handles else {
                return Err(zx::Status::INVALID_ARGS);
            };
            let mut state = self.state.lock();
            state.buffer.push(Data { bytes, handles });
            state
                .remote_event
                .as_handle_ref()
                .signal(zx::Signals::NONE, READABLE_SIGNAL | WRITABLE_SIGNAL)?;
            Ok(fbinder::UnixDomainSocketWriteResponse {
                actual_count: Some(actual_count),
                ..Default::default()
            })
        }

        /// Serves requests on `channel` until it closes; `Clone` requests are served on the
        /// same shared state.
        async fn serve(self: &Arc<Self>, channel: zx::Channel) {
            let stream = fbinder::UnixDomainSocketRequestStream::from_channel(
                fasync::Channel::from_channel(channel),
            );
            stream
                .for_each_concurrent(None, |message| async {
                    match message {
                        Ok(fbinder::UnixDomainSocketRequest::GetEvent { responder, .. }) => {
                            let state = self.state.lock();
                            let event = state
                                .remote_event
                                .duplicate_handle(zx::Rights::SAME_RIGHTS)
                                .expect("duplicate event");
                            responder
                                .send(Ok(fbinder::UnixDomainSocketGetEventResponse {
                                    event: Some(event),
                                    ..Default::default()
                                }))
                                .expect("respond");
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Read {
                            payload, responder, ..
                        }) => {
                            assert!(
                                responder
                                    .send(self.read(payload).map_err(|e| e.into_raw()))
                                    .is_ok()
                            );
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Write {
                            payload, responder, ..
                        }) => {
                            assert!(
                                responder
                                    .send(self.write(payload).as_ref().map_err(|e| e.into_raw()))
                                    .is_ok()
                            );
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Query { responder }) => {
                            assert!(
                                responder
                                    .send(fbinder::UnixDomainSocketMarker::PROTOCOL_NAME.as_bytes())
                                    .is_ok()
                            );
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Clone { request, .. }) => {
                            self.serve(request.into()).await;
                        }
                        Ok(fbinder::UnixDomainSocketRequest::Close { responder }) => {
                            assert!(responder.send(Ok(())).is_ok());
                        }
                        _ => {
                            return;
                        }
                    }
                })
                .await;
        }
    }

    #[::fuchsia::test]
    async fn test_remote_uds() {
        let (client, server) = zx::Channel::create();
        let handle = std::thread::spawn(move || {
            let mut executor = fasync::LocalExecutor::default();
            executor.run_singlethreaded(async move {
                let uds_impl = UnixDomainSocketImpl::default();
                Arc::new(uds_impl).serve(server).await;
            });
        });
        spawn_kernel_and_run(async move |locked, current_task| {
            let original_file =
                new_remote_file(locked, current_task, client.into(), OpenFlags::RDWR)
                    .expect("new_remote_file");
            assert!(original_file.node().is_sock());
            let file = new_remote_file(
                locked,
                current_task,
                original_file.to_handle(current_task).expect("to_handle").expect("has_handle"),
                OpenFlags::RDWR,
            )
            .expect("new_remote_file");
            let ancillary_data =
                vec![AncillaryData::Unix(UnixControlData::Rights(vec![original_file]))];
            let socket_ops = file.downcast_file::<SocketFile>().unwrap();
            let data = "HelloWorld";
            let mut input_buffer = VecInputBuffer::new(data.as_bytes());
            assert_eq!(
                socket_ops.sendmsg(
                    locked,
                    current_task,
                    &file,
                    &mut input_buffer,
                    None,
                    ancillary_data,
                    SocketMessageFlags::empty()
                ),
                Ok(data.len())
            );

            let flags = SocketMessageFlags::CTRUNC
                | SocketMessageFlags::TRUNC
                | SocketMessageFlags::NOSIGNAL
                | SocketMessageFlags::CMSG_CLOEXEC;

            let mut buffer = VecOutputBuffer::new(1024);
            let info = socket_ops
                .recvmsg(
                    locked,
                    current_task,
                    &file,
                    &mut buffer,
                    flags | SocketMessageFlags::PEEK,
                    None,
                )
                .expect("recvmsg");

            assert_eq!(info.ancillary_data.len(), 1);
            assert_eq!(info.message_length, data.len());

            let mut buffer = VecOutputBuffer::new(1024);
            let info = socket_ops
                .recvmsg(locked, current_task, &file, &mut buffer, flags, None)
                .expect("recvmsg");

            assert_eq!(info.ancillary_data.len(), 1);
            assert_eq!(info.message_length, data.len());

            let mut buffer = VecOutputBuffer::new(1024);
            let err = socket_ops
                .recvmsg(
                    locked,
                    current_task,
                    &file,
                    &mut buffer,
                    flags | SocketMessageFlags::DONTWAIT,
                    None,
                )
                .unwrap_err();
            assert_eq!(err, errno!(EAGAIN));
        })
        .await;
        handle.join().expect("join");
    }

    #[::fuchsia::test]
    async fn test_remote_uds_peer_closed() {
        let (client, server) = zx::Channel::create();
        let handle = std::thread::spawn(move || {
            let mut executor = fasync::LocalExecutor::default();
            executor.run_singlethreaded(async move {
                let uds_impl = UnixDomainSocketImpl { close_on_read: true, ..Default::default() };
                Arc::new(uds_impl).serve(server).await;
            });
        });
        spawn_kernel_and_run(async move |locked, current_task| {
            let file = new_remote_file(locked, current_task, client.into(), OpenFlags::RDWR)
                .expect("new_remote_file");
            let socket_ops = file.downcast_file::<SocketFile>().unwrap();

            let mut buffer = VecOutputBuffer::new(1024);
            let err = socket_ops
                .recvmsg(
                    locked,
                    current_task,
                    &file,
                    &mut buffer,
                    SocketMessageFlags::empty(),
                    None,
                )
                .unwrap_err();
            assert_eq!(err, errno!(ECONNRESET));
        })
        .await;
        handle.join().expect("join");
    }
}