1use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::HashMap;
9use std::collections::hash_map::Entry;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker, ready};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23 Address, Header, Packet, PacketType, ProtocolVersion, ShutdownError, UsbPacketBuilder,
24 UsbPacketFiller, WritePacketErrorExt,
25};
26
27mod overflow_writer;
28mod pause_state;
29
30use overflow_writer::OverflowWriter;
31use pause_state::PauseState;
32
/// Byte buffer type that USB packets are assembled into.
///
/// Anything that mutably dereferences to a byte slice, can be sent between threads, is `Unpin`
/// and owns its data (`'static`) qualifies; the blanket impl below makes that automatic.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
36
/// Payload of a `Pause` control packet: either pause or resume the peer's sending.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    Pause,
    UnPause,
}

impl PausePacket {
    /// Wire encoding of this packet's one-byte payload: `1` for pause, `0` for unpause.
    fn bytes(&self) -> [u8; 1] {
        [u8::from(*self == PausePacket::Pause)]
    }
}
51
52pub struct ReadyConnect<B, S> {
55 connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
56 packet_filler: Arc<UsbPacketFiller<B>>,
57 address: Address,
58}
59
60impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
61 pub async fn finish_connect(self, socket: S) {
63 let (read_socket, write_socket) = socket.split();
64 let writer = {
65 let conns = self.connections.lock();
66 let Some(conn) = conns.get(&self.address) else {
67 warn!("Connection state was missing after connection success!");
68 return;
69 };
70 let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
71 &conn.state
72 else {
73 warn!("Connection state was invalid after connection success!");
74 return;
75 };
76 reader_scope.spawn(Connection::<B, S>::run_socket(
77 read_socket,
78 self.address,
79 self.packet_filler,
80 Arc::clone(pause_state),
81 ));
82 Arc::clone(writer)
83 };
84 let mut writer = writer.lock().await;
85 let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
86 &mut *writer,
87 ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
88 ) else {
89 unreachable!("Connection completed multiple times!")
90 };
91
92 wakers.into_iter().for_each(Waker::wake);
93 }
94}
95
/// Multiplexes vsock connections over a single USB packet stream.
///
/// Outgoing vsock packets are queued into a shared `UsbPacketFiller` and drained with
/// `fill_usb_packet`; incoming packets are fed in through `handle_vsock_packet`.
pub struct Connection<B, S> {
    /// Write half of the optional control socket. Data packets for the all-zeros address are
    /// written here by `handle_data_packet`; when `None`, such data is discarded.
    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
    /// Queue of outgoing vsock packets, shared with the socket reader tasks.
    packet_filler: Arc<UsbPacketFiller<B>>,
    /// Negotiated protocol version; gates whether pause packets are sent and honored.
    protocol_version: ProtocolVersion,
    /// Per-address connection table, shared with any `ReadyConnect` handles handed out.
    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
    /// Channel on which connect requests from the other side are delivered to the client.
    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    /// Scope owning the control-socket reader task and overflow-drain tasks.
    task_scope: Scope,
}
116
117impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
118 pub fn new(
125 protocol_version: ProtocolVersion,
126 control_socket: Option<S>,
127 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
128 ) -> Self {
129 let packet_filler = Arc::new(UsbPacketFiller::default());
130 let connections = Default::default();
131 let task_scope = Scope::new_with_name("vsock_usb");
132 let control_socket_writer = control_socket.map(|control_socket| {
133 let (control_socket_reader, control_socket_writer) = control_socket.split();
134 task_scope.spawn(Self::run_socket(
135 control_socket_reader,
136 Address::default(),
137 packet_filler.clone(),
138 PauseState::new(),
139 ));
140 Mutex::new(control_socket_writer)
141 });
142 Self {
143 control_socket_writer,
144 packet_filler,
145 connections,
146 incoming_requests_tx,
147 protocol_version,
148 task_scope,
149 }
150 }
151
152 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
153 let header = &mut Header::new(PacketType::Finish);
154 header.set_address(address);
155 let _: Result<_, ShutdownError> = usb_packet_filler
156 .write_vsock_packet(&Packet { header, payload: &[] })
157 .await
158 .expect_right_size("Finish packet should never be too big");
159 }
160
161 async fn run_socket(
162 mut reader: ReadHalf<S>,
163 address: Address,
164 usb_packet_filler: Arc<UsbPacketFiller<B>>,
165 pause_state: Arc<PauseState>,
166 ) {
167 let mut buf = [0; 4096];
168 loop {
169 log::trace!("reading from control socket");
170 let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
171 Ok(0) => {
172 if !address.is_zeros() {
173 Self::send_close_packet(&address, &usb_packet_filler).await;
174 }
175 return;
176 }
177 Ok(read) => read,
178 Err(err) => {
179 if address.is_zeros() {
180 log::error!("Error reading usb socket: {err:?}");
181 } else {
182 Self::send_close_packet(&address, &usb_packet_filler).await;
183 }
184 return;
185 }
186 };
187 log::trace!("writing {read} bytes to vsock packet");
188 if usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await.is_err() {
189 log::trace!("transport shut down during read");
190 return;
191 }
192 log::trace!("wrote {read} bytes to vsock packet");
193 }
194 }
195
196 fn set_connection(
197 &self,
198 address: Address,
199 state: VsockConnectionState<S>,
200 ) -> Result<(), Error> {
201 let mut connections = self.connections.lock();
202 if !connections.contains_key(&address) {
203 connections.insert(address.clone(), VsockConnection { _address: address, state });
204 Ok(())
205 } else {
206 Err(Error::other(format!("connection on address {address:?} already set")))
207 }
208 }
209
210 pub async fn send_empty_echo(&self) {
213 debug!("Sending empty echo packet");
214 let header = &mut Header::new(PacketType::Echo);
215 let _: Result<_, ShutdownError> = self
216 .packet_filler
217 .write_vsock_packet(&Packet { header, payload: &[] })
218 .await
219 .expect_right_size(
220 "empty echo packet should never be too large to fit in a usb packet",
221 );
222 }
223
224 pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
229 let (ready, state) = self.connect_late(addr).await?;
230 ready.finish_connect(socket).await;
231 Ok(state)
232 }
233
234 pub async fn connect_late(
241 &self,
242 addr: Address,
243 ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
244 let (connected_tx, connected_rx) = oneshot::channel();
245
246 self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
247
248 let header = &mut Header::new(PacketType::Connect);
249 header.set_address(&addr);
250 self.packet_filler
251 .write_vsock_packet(&Packet { header, payload: &[] })
252 .await
253 .assert_right_size()?;
254 let Ok(conn_state) = connected_rx.await else {
255 return Err(Error::other("Accept was never received for {addr:?}"));
256 };
257
258 Ok((
259 ReadyConnect {
260 connections: Arc::clone(&self.connections),
261 packet_filler: Arc::clone(&self.packet_filler),
262 address: addr,
263 },
264 conn_state,
265 ))
266 }
267
268 pub async fn close(&self, address: &Address) {
270 Self::send_close_packet(address, &self.packet_filler).await
271 }
272
273 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
275 reset(address, &self.connections, &self.packet_filler).await
276 }
277
278 pub async fn accept(
282 &self,
283 request: ConnectionRequest,
284 socket: S,
285 ) -> Result<ConnectionState, Error> {
286 let (ready, state) = self.accept_late(request).await?;
287 ready.finish_connect(socket).await;
288 Ok(state)
289 }
290
291 pub async fn accept_late(
295 &self,
296 request: ConnectionRequest,
297 ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
298 let address = request.address;
299 let notify_closed_rx;
300 if let Some(conn) = self.connections.lock().get_mut(&address) {
301 let VsockConnectionState::ConnectingIncoming = &conn.state else {
302 return Err(Error::other(format!(
303 "Attempted to accept connection that was not waiting at {address:?}"
304 )));
305 };
306
307 let notify_closed = mpsc::channel(2);
308 notify_closed_rx = notify_closed.1;
309 let notify_closed = notify_closed.0;
310 let pause_state = PauseState::new();
311
312 let reader_scope = Scope::new_with_name("connection-reader");
313
314 conn.state = VsockConnectionState::Connected {
315 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
316 reader_scope,
317 notify_closed,
318 pause_state,
319 };
320 } else {
321 return Err(Error::other(format!(
322 "Attempting to accept connection that did not exist at {address:?}"
323 )));
324 }
325 let header = &mut Header::new(PacketType::Accept);
326 header.set_address(&address);
327 self.packet_filler
328 .write_vsock_packet(&Packet { header, payload: &[] })
329 .await
330 .assert_right_size()?;
331 Ok((
332 ReadyConnect {
333 connections: Arc::clone(&self.connections),
334 packet_filler: Arc::clone(&self.packet_filler),
335 address,
336 },
337 ConnectionState(notify_closed_rx),
338 ))
339 }
340
341 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
343 let address = request.address;
344 match self.connections.lock().entry(address.clone()) {
345 Entry::Occupied(entry) => {
346 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
347 return Err(Error::other(format!(
348 "Attempted to reject connection that was not waiting at {address:?}"
349 )));
350 };
351 entry.remove();
352 }
353 Entry::Vacant(_) => {
354 return Err(Error::other(format!(
355 "Attempted to reject connection that was not waiting at {address:?}"
356 )));
357 }
358 }
359
360 let header = &mut Header::new(PacketType::Reset);
361 header.set_address(&address);
362 self.packet_filler
363 .write_vsock_packet(&Packet { header, payload: &[] })
364 .await
365 .expect_right_size("accept packet should never be too large for packet buffer")?;
366 Ok(())
367 }
368
369 async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
370 if address.is_zeros() {
372 if let Some(writer) = self.control_socket_writer.as_ref() {
373 writer.lock().await.write_all(payload).await?;
374 } else {
375 trace!("Discarding {} bytes of data sent to control socket", payload.len());
376 }
377 Ok(())
378 } else {
379 let payload_socket;
380 if let Some(conn) = self.connections.lock().get_mut(&address) {
381 let VsockConnectionState::Connected { writer, .. } = &conn.state else {
382 warn!(
383 "Received data packet for connection in unexpected state for {address:?}"
384 );
385 return Ok(());
386 };
387 payload_socket = writer.clone();
388 } else {
389 warn!("Received data packet for connection that didn't exist at {address:?}");
390 return Ok(());
391 }
392 let mut socket_guard =
393 ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
394 let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
395 unreachable!("wait_available didn't wait until socket was available!");
396 };
397 match socket.write_all(payload) {
398 Err(err) => {
399 debug!(
400 "Write to socket address {address:?} failed, \
401 resetting connection immediately: {err:?}"
402 );
403 self.reset(&address)
404 .await
405 .inspect_err(|err| {
406 warn!(
407 "Attempt to reset connection to {address:?} \
408 failed after write error: {err:?}"
409 );
410 })
411 .ok();
412 }
413 Ok(status) => {
414 if status.overflowed() {
415 if self.protocol_version.has_pause_packets() {
416 let header = &mut Header::new(PacketType::Pause);
417 let payload = &PausePacket::Pause.bytes();
418 header.set_address(&address);
419 header.payload_len.set(payload.len() as u32);
420 self.packet_filler
421 .write_vsock_packet(&Packet { header, payload })
422 .await
423 .expect_right_size(
424 "pause packet should never be too large to fit in a usb packet",
425 )?;
426 }
427
428 let weak_payload_socket = Arc::downgrade(&payload_socket);
429 let connections = Arc::clone(&self.connections);
430 let has_pause_packets = self.protocol_version.has_pause_packets();
431 let packet_filler = Arc::clone(&self.packet_filler);
432 self.task_scope.spawn(async move {
433 let res = OverflowHandleFut::new(weak_payload_socket).await;
434
435 if let Err(err) = res {
436 debug!(
437 "Write to socket address {address:?} failed while \
438 processing backlog, resetting connection at next poll: {err:?}"
439 );
440 if let Err(err) = reset(&address, &connections, &packet_filler).await {
441 debug!("Error sending reset frame after overflow write failed: {err:?}");
442 }
443 } else if has_pause_packets {
444 let header = &mut Header::new(PacketType::Pause);
445 let payload = &PausePacket::UnPause.bytes();
446 header.set_address(&address);
447 header.payload_len.set(payload.len() as u32);
448 let _: Result<_, ShutdownError> =
449 packet_filler
450 .write_vsock_packet(&Packet { header, payload })
451 .await
452 .expect_right_size("pause packet should never be too large to fit in a usb packet");
453 }
454 });
455 }
456 }
457 }
458 Ok(())
459 }
460 }
461
462 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
463 debug!("received echo for {address:?} with payload {payload:?}");
464 let header = &mut Header::new(PacketType::EchoReply);
465 header.payload_len.set(payload.len() as u32);
466 header.set_address(&address);
467 self.packet_filler.write_vsock_packet(&Packet { header, payload }).await.map_err(
468 |e| match e {
469 crate::WritePacketError::PacketTooBig(_) => {
470 Error::other("Echo packet was too large to be sent back")
471 }
472 crate::WritePacketError::Shutdown(shutdown_error) => shutdown_error.into(),
473 },
474 )
475 }
476
477 async fn handle_echo_reply_packet(
478 &self,
479 address: Address,
480 payload: &[u8],
481 ) -> Result<(), Error> {
482 debug!("received echo reply for {address:?} with payload {payload:?}");
484 Ok(())
485 }
486
487 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
488 if let Some(conn) = self.connections.lock().get_mut(&address) {
489 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
490 let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
491 warn!("Received accept packet for connection in unexpected state for {address:?}");
492 return Ok(());
493 };
494 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
495 if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
496 warn!(
497 "Accept packet received for {address:?} but connect caller stopped waiting for it"
498 );
499 }
500 let pause_state = PauseState::new();
501
502 let reader_scope = Scope::new_with_name("connection-reader");
503 conn.state = VsockConnectionState::Connected {
504 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
505 reader_scope,
506 notify_closed,
507 pause_state,
508 };
509 } else {
510 warn!("Got accept packet for connection that was not being made at {address:?}");
511 return Ok(());
512 }
513 Ok(())
514 }
515
516 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
517 trace!("received connect packet for {address:?}");
518 match self.connections.lock().entry(address.clone()) {
519 Entry::Vacant(entry) => {
520 debug!("valid connect request for {address:?}");
521 entry.insert(VsockConnection {
522 _address: address,
523 state: VsockConnectionState::ConnectingIncoming,
524 });
525 }
526 Entry::Occupied(_) => {
527 warn!(
528 "Received connect packet for already existing \
529 connection for address {address:?}. Ignoring"
530 );
531 return Ok(());
532 }
533 }
534
535 trace!("sending incoming connection request to client for {address:?}");
536 let connection_request = ConnectionRequest { address };
537 self.incoming_requests_tx
538 .clone()
539 .send(connection_request)
540 .await
541 .inspect(|_| trace!("sent incoming request for {address:?}"))
542 .map_err(|_| Error::other("Failed to send connection request"))
543 }
544
545 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
546 trace!("received finish packet for {address:?}");
547 let mut notify;
548 if let Some(conn) = self.connections.lock().remove(&address) {
549 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
550 warn!(
551 "Received finish (close) packet for {address:?} \
552 which was not in a connected state. Ignoring and dropping connection state."
553 );
554 return Ok(());
555 };
556 notify = notify_closed;
557 } else {
558 warn!(
559 "Received finish (close) packet for connection that didn't exist \
560 on address {address:?}. Ignoring"
561 );
562 return Ok(());
563 }
564
565 notify.send(Ok(())).await.ok();
566
567 let header = &mut Header::new(PacketType::Reset);
568 header.set_address(&address);
569 self.packet_filler
570 .write_vsock_packet(&Packet { header, payload: &[] })
571 .await
572 .expect_right_size("accept packet should never be too large for packet buffer")?;
573 Ok(())
574 }
575
576 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
577 trace!("received reset packet for {address:?}");
578 let mut notify = None;
579 if let Some(conn) = self.connections.lock().remove(&address) {
580 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
581 notify = Some(notify_closed);
582 } else {
583 debug!(
584 "Received reset packet for connection that wasn't in a connecting or \
585 disconnected state on address {address:?}."
586 );
587 }
588 } else {
589 warn!(
590 "Received reset packet for connection that didn't \
591 exist on address {address:?}. Ignoring"
592 );
593 }
594
595 if let Some(mut notify) = notify {
596 notify.send(Ok(())).await.ok();
597 }
598 Ok(())
599 }
600
601 async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
602 if !self.protocol_version.has_pause_packets() {
603 warn!(
604 "Got a pause packet while using protocol \
605 version {} which does not support them. Ignoring",
606 self.protocol_version
607 );
608 return Ok(());
609 }
610
611 let pause = match payload {
612 [1] => true,
613 [0] => false,
614 other => {
615 warn!("Ignoring unexpected pause packet payload {other:?}");
616 return Ok(());
617 }
618 };
619
620 if let Some(conn) = self.connections.lock().get(&address) {
621 if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
622 pause_state.set_paused(pause);
623 } else {
624 warn!("Received pause packet for unestablished connection. Ignoring");
625 };
626 } else {
627 warn!(
628 "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
629 );
630 }
631
632 Ok(())
633 }
634
635 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
638 trace!("received vsock packet {header:?}", header = packet.header);
639 let payload_len = packet.header.payload_len.get() as usize;
640 let payload = &packet.payload[..payload_len];
641 let address = Address::from(packet.header);
642 match packet.header.packet_type {
643 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
644 PacketType::Data => self.handle_data_packet(address, payload).await,
645 PacketType::Accept => self.handle_accept_packet(address).await,
646 PacketType::Connect => self.handle_connect_packet(address).await,
647 PacketType::Finish => self.handle_finish_packet(address).await,
648 PacketType::Reset => self.handle_reset_packet(address).await,
649 PacketType::Echo => self.handle_echo_packet(address, payload).await,
650 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
651 PacketType::Pause => self.handle_pause_packet(address, payload).await,
652 }
653 }
654
655 pub async fn fill_usb_packet(
662 &self,
663 builder: UsbPacketBuilder<B>,
664 ) -> Result<UsbPacketBuilder<B>, ShutdownError> {
665 self.packet_filler.fill_usb_packet(builder).await
666 }
667}
668
669impl<B: PacketBuffer, S> Connection<B, S> {
670 pub fn shutdown(&self) {
673 self.packet_filler.shutdown();
674 self.connections.lock().clear();
675 }
676}
677
678async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
679 address: &Address,
680 connections: &fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>,
681 packet_filler: &UsbPacketFiller<B>,
682) -> Result<(), Error> {
683 let mut notify = None;
684 if let Some(conn) = connections.lock().remove(&address) {
685 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
686 notify = Some(notify_closed);
687 }
688 } else {
689 return Err(Error::other(
690 "Client asked to reset connection {address:?} that did not exist",
691 ));
692 }
693
694 if let Some(mut notify) = notify {
695 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
696 }
697
698 let header = &mut Header::new(PacketType::Reset);
699 header.set_address(address);
700 packet_filler
701 .write_vsock_packet(&Packet { header, payload: &[] })
702 .await
703 .expect_right_size("Reset packet should never be too big")?;
704 Ok(())
705}
706
707enum ConnectionStateWriter<S> {
712 NotYetAvailable(Vec<Waker>),
713 Available(OverflowWriter<S>),
714}
715
716impl<S> ConnectionStateWriter<S> {
717 fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
719 ConnectionStateWriterFut { writer: this, lock_fut: None }
720 }
721}
722
723struct ConnectionStateWriterFut<S> {
725 writer: Arc<Mutex<ConnectionStateWriter<S>>>,
726 lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
727}
728
729impl<S> Future for ConnectionStateWriterFut<S> {
730 type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;
731
732 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
733 let writer = Arc::clone(&self.writer);
734 let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
735 let mut lock = ready!(lock_fut.poll_unpin(cx));
736 self.lock_fut = None;
737 match &mut *lock {
738 ConnectionStateWriter::Available(_) => Poll::Ready(lock),
739 ConnectionStateWriter::NotYetAvailable(queue) => {
740 queue.push(cx.waker().clone());
741 Poll::Pending
742 }
743 }
744 }
745}
746
/// Lifecycle state of a single vsock connection record.
enum VsockConnectionState<S> {
    /// A `Connect` packet was sent and we are waiting for `Accept`. The sender hands the
    /// connection's `ConnectionState` to the waiting `connect_late` call.
    ConnectingOutgoing(oneshot::Sender<ConnectionState>),
    /// The other side sent `Connect`; waiting for the local client to accept or reject it.
    ConnectingIncoming,
    /// The connection is established.
    Connected {
        /// Write half of the local socket, attached later by `ReadyConnect::finish_connect`.
        writer: Arc<Mutex<ConnectionStateWriter<S>>>,
        /// Reports close (`Ok`) or reset (`Err`) to the `ConnectionState` holder.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Set by pause packets from the other side; gates reads from the local socket.
        pause_state: Arc<PauseState>,
        /// Scope that the local socket's reader task is spawned on.
        reader_scope: Scope,
    },
    /// Placeholder written by `handle_accept_packet` while it takes ownership of the previous
    /// state.
    Invalid,
}

/// Entry in the connection table.
struct VsockConnection<S> {
    // Address this record is keyed under; kept for reference but not read in this file.
    _address: Address,
    // Current lifecycle state.
    state: VsockConnectionState<S>,
}
763
764#[derive(Debug)]
768pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
769
770impl ConnectionState {
771 pub async fn wait_for_close(mut self) -> Result<(), Error> {
774 self.0
775 .next()
776 .await
777 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
778 }
779}
780
781#[derive(Debug)]
784pub struct ConnectionRequest {
785 address: Address,
786}
787
788impl ConnectionRequest {
789 pub fn new(address: Address) -> Self {
791 Self { address }
792 }
793
794 pub fn address(&self) -> &Address {
796 &self.address
797 }
798}
799
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use test_case::test_case;

    use crate::VsockPacketIterator;

    use super::*;

    // Host builds use the emulated socket handle; Fuchsia builds use the real zircon socket.
    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Loops `echo_connection`'s outgoing packets back into it.
    ///
    /// Outgoing `Connect` packets are answered with an `Accept`. Outgoing `Accept` packets are
    /// dropped. Every other packet is fed back as if the other side had sent it.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await.unwrap();
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        let mut reply_header = packet.header.clone();
                        // Pretend the other side accepted: reply with an Accept for the same
                        // address.
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // Our own Accept (from accepting an incoming request) needs no reply.
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    /// Data written into the control socket comes back out of it via the echo server.
    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Data written into an outgoing connection's socket is echoed back to it.
    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection
            .connect(
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                Socket::from_socket(other_socket),
            )
            .await
            .unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// A Connect injected by hand surfaces as an incoming request; once accepted, data on its
    /// socket is echoed back.
    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(
            request.address,
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
        );

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Forever forwards every vsock packet produced by `from` into `to`, panicking on error.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await.unwrap();
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// Shape of one side of an `end_to_end_test`: takes that side's connection and its
    /// incoming-request receiver.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Wires two `Connection`s back to back and runs `left_side` and `right_side`
    /// concurrently against them.
    ///
    /// A passthrough task copies packets in both directions. It is aborted once both sides
    /// finish, and their results are returned.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>, Socket>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket1)),
            incoming_requests_tx1,
        ));
        let connection2 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket2)),
            incoming_requests_tx2,
        ));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    /// Data written on one side arrives in order on the other. Dropping the writer's socket
    /// closes the connection cleanly on both sides.
    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn
                    .connect(
                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                        Socket::from_socket(other_socket),
                    )
                    .await
                    .unwrap();

                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(
                    request.address,
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
                );

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `close` on one side makes both local sockets hit EOF, and both sides see a clean close.
    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `reset` on one side reports an error to that side's observer. The peer sees its socket
    /// hit EOF and a clean close.
    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `shutdown` makes a pending `connect_late` fail. When `fill_packets` is set, it also
    /// makes a concurrently pending `fill_usb_packet` fail.
    #[test_case(false; "in packet handling")]
    #[test_case(true; "in reply wait")]
    #[fuchsia::test]
    async fn conn_shutdown(fill_packets: bool) {
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);

        let connection = Arc::new(Connection::<Vec<u8>, fuchsia_async::Socket>::new(
            ProtocolVersion::LATEST,
            None,
            incoming_requests_tx,
        ));

        let mut filler = if fill_packets {
            Some(std::pin::pin!(connection.fill_usb_packet(UsbPacketBuilder::new(Vec::new()))))
        } else {
            None
        };

        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        let mut fut = std::pin::pin!(connection.connect_late(addr));

        // Drive both futures by hand a few times to make sure they are parked mid-flight.
        for _ in 0..5 {
            assert!(fut.as_mut().poll(&mut Context::from_waker(Waker::noop())).is_pending());
            if let Some(filler) = filler.as_mut() {
                assert!(filler.as_mut().poll(&mut Context::from_waker(Waker::noop())).is_pending())
            }
        }

        connection.shutdown();
        let Poll::Ready(res) = fut.poll(&mut Context::from_waker(Waker::noop())) else { panic!() };
        assert!(res.is_err());
        if let Some(filler) = filler {
            let Poll::Ready(res) = filler.poll(&mut Context::from_waker(Waker::noop())) else {
                panic!()
            };
            assert!(res.is_err());
        }
    }
}
1136}