overnet_core/proxy/handle/socket.rs

// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use super::signals::Collector;
6use super::{
7    IntoProxied, Message, Proxyable, ProxyableRW, ReadValue, RouterHolder, Serializer, IO,
8};
9use crate::peer::PeerConnRef;
10use anyhow::Error;
11use fidl::{AsHandleRef, AsyncSocket, HandleBased, Peered, Signals};
12use futures::io::{AsyncRead, AsyncWrite};
13use futures::ready;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use zx_status;
17
18pub(crate) struct Socket {
19    socket: AsyncSocket,
20}
21
22impl std::fmt::Debug for Socket {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        "Socket".fmt(f)
25    }
26}
27
28impl Proxyable for Socket {
29    type Message = SocketMessage;
30
31    fn from_fidl_handle(hdl: fidl::Handle) -> Result<Self, Error> {
32        Ok(fidl::Socket::from_handle(hdl).into_proxied()?)
33    }
34
35    fn into_fidl_handle(self) -> Result<fidl::Handle, Error> {
36        Ok(self.socket.into_zx_socket().into_handle())
37    }
38
39    fn signal_peer(&self, clear: Signals, set: Signals) -> Result<(), Error> {
40        self.socket.as_ref().signal_peer(clear, set)?;
41        Ok(())
42    }
43}
44
45impl<'a> ProxyableRW<'a> for Socket {
46    type Reader = SocketReader<'a>;
47    type Writer = SocketWriter;
48}
49
50impl IntoProxied for fidl::Socket {
51    type Proxied = Socket;
52    fn into_proxied(self) -> Result<Socket, Error> {
53        Ok(Socket { socket: AsyncSocket::from_socket(self) })
54    }
55}
56
57pub(crate) struct SocketReader<'a> {
58    collector: Collector<'a>,
59}
60
61impl<'a> IO<'a> for SocketReader<'a> {
62    type Proxyable = Socket;
63    type Output = ReadValue;
64    fn new() -> Self {
65        SocketReader { collector: Default::default() }
66    }
67    fn poll_io(
68        &mut self,
69        msg: &mut SocketMessage,
70        socket: &'a Socket,
71        fut_ctx: &mut Context<'_>,
72    ) -> Poll<Result<ReadValue, zx_status::Status>> {
73        const MIN_READ_LEN: usize = 65536;
74        if msg.0.len() < MIN_READ_LEN {
75            msg.0.resize(MIN_READ_LEN, 0u8);
76        }
77        let read_result = (|| {
78            let n = ready!(Pin::new(&mut &socket.socket).poll_read(fut_ctx, &mut msg.0))?;
79            if n == 0 {
80                return Poll::Ready(Err(zx_status::Status::PEER_CLOSED));
81            }
82            msg.0.truncate(n);
83            Poll::Ready(Ok(()))
84        })();
85        self.collector.after_read(fut_ctx, socket.socket.as_handle_ref(), read_result, false)
86    }
87}
88
89pub(crate) struct SocketWriter;
90
91impl IO<'_> for SocketWriter {
92    type Proxyable = Socket;
93    type Output = ();
94    fn new() -> Self {
95        SocketWriter
96    }
97    fn poll_io(
98        &mut self,
99        msg: &mut SocketMessage,
100        socket: &Socket,
101        fut_ctx: &mut Context<'_>,
102    ) -> Poll<Result<(), zx_status::Status>> {
103        while !msg.0.is_empty() {
104            let n = ready!(Pin::new(&mut &socket.socket).poll_write(fut_ctx, &msg.0))?;
105            msg.0.drain(..n);
106        }
107        Poll::Ready(Ok(()))
108    }
109}
110
/// The raw bytes carried across the proxy for a socket.
///
/// `Debug` output is exactly that of the inner `Vec<u8>`, e.g. `[1, 2, 3]`.
#[derive(Default, PartialEq)]
pub(crate) struct SocketMessage(Vec<u8>);

impl std::fmt::Debug for SocketMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SocketMessage(bytes) = self;
        std::fmt::Debug::fmt(bytes, f)
    }
}
119
120impl Message for SocketMessage {
121    type Parser = SocketMessageSerializer;
122    type Serializer = SocketMessageSerializer;
123}
124
125#[derive(Debug)]
126pub(crate) struct SocketMessageSerializer;
127
128impl Serializer for SocketMessageSerializer {
129    type Message = SocketMessage;
130    fn new() -> SocketMessageSerializer {
131        SocketMessageSerializer
132    }
133    fn poll_ser(
134        &mut self,
135        msg: &mut SocketMessage,
136        bytes: &mut Vec<u8>,
137        _: PeerConnRef<'_>,
138        _: &mut RouterHolder<'_>,
139        _: &mut Context<'_>,
140    ) -> Poll<Result<(), Error>> {
141        std::mem::swap(bytes, &mut msg.0);
142        Poll::Ready(Ok(()))
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use futures::AsyncReadExt as _;

    /// Pushes twice the kernel socket buffer through `SocketWriter` and checks that every byte
    /// arrives intact, which forces the writer through its partial-write / resume path.
    #[fuchsia::test]
    async fn stream_socket_partial_write() {
        let (tx, rx) = fidl::Socket::create_stream();
        let socket = tx.into_proxied().expect("create proxied socket");

        const KERNEL_BUF_SIZE: usize = 257024;
        const EXPECTED_DATA: u8 = 0xff;
        const EXPECTED_LEN: usize = KERNEL_BUF_SIZE * 2;

        let mut writer = SocketWriter::new();
        let mut msg = SocketMessage(vec![EXPECTED_DATA; EXPECTED_LEN]);
        // Write more than the size of the underlying kernel buffer into the
        // proxied socket to exercise that overnet handles partial writes to the
        // zircon socket correctly.
        let write_task = fuchsia_async::Task::spawn(async move {
            futures::future::poll_fn(move |cx| writer.poll_io(&mut msg, &socket, cx))
                .await
                .expect("write to socket")
        });

        let mut data = vec![0u8; EXPECTED_LEN];
        let mut rx = fuchsia_async::Socket::from_socket(rx);
        rx.read_exact(&mut data).await.expect("read from socket");
        assert_eq!(data, vec![EXPECTED_DATA; EXPECTED_LEN]);

        // Await the writer instead of detaching it. This proves `poll_io` reported completion
        // once the whole message was flushed, and surfaces any write failure in this test.
        write_task.await;
    }
}