1use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::Stream;
10use futures::task::{Context, Poll, Waker};
11use futures::{ready, Future, FutureExt, TryFutureExt};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use zx::{self as zx, MonotonicDuration};
21
22#[cfg(test)]
23mod tests;
24
25mod rtp;
26mod stream_endpoint;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33 MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36 ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37 Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
/// A handle to the signaling connection with a remote peer.
///
/// Cloning is cheap: every clone shares the same reference-counted [`PeerInner`].
#[derive(Debug, Clone)]
pub struct Peer {
    /// Shared state: the signaling channel, the response waiters and the incoming request queue.
    inner: Arc<PeerInner>,
}
53
54impl Peer {
55 pub fn new(signaling: Channel) -> Self {
57 Self {
58 inner: Arc::new(PeerInner {
59 signaling,
60 response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
61 incoming_requests: Mutex::<RequestQueue>::default(),
62 }),
63 }
64 }
65
66 #[track_caller]
69 pub fn take_request_stream(&self) -> RequestStream {
70 {
71 let mut lock = self.inner.incoming_requests.lock();
72 if let RequestListener::None = lock.listener {
73 lock.listener = RequestListener::New;
74 } else {
75 panic!("Request stream has already been taken");
76 }
77 }
78
79 RequestStream { inner: self.inner.clone() }
80 }
81
82 pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> {
86 self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
87 }
88
89 pub fn get_capabilities(
96 &self,
97 stream_id: &StreamEndpointId,
98 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
99 let stream_params = &[stream_id.to_msg()];
100 self.send_command::<GetCapabilitiesResponse>(
101 SignalIdentifier::GetCapabilities,
102 stream_params,
103 )
104 .ok_into()
105 }
106
107 pub fn get_all_capabilities(
113 &self,
114 stream_id: &StreamEndpointId,
115 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
116 let stream_params = &[stream_id.to_msg()];
117 self.send_command::<GetCapabilitiesResponse>(
118 SignalIdentifier::GetAllCapabilities,
119 stream_params,
120 )
121 .ok_into()
122 }
123
124 pub fn set_configuration(
131 &self,
132 stream_id: &StreamEndpointId,
133 local_stream_id: &StreamEndpointId,
134 capabilities: &[ServiceCapability],
135 ) -> impl Future<Output = Result<()>> {
136 assert!(!capabilities.is_empty(), "must set at least one capability");
137 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
138 params[0] = stream_id.to_msg();
139 params[1] = local_stream_id.to_msg();
140 let mut idx = 2;
141 for capability in capabilities {
142 if let Err(e) = capability.encode(&mut params[idx..]) {
143 return futures::future::err(e).left_future();
144 }
145 idx += capability.encoded_len();
146 }
147 self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, ¶ms)
148 .ok_into()
149 .right_future()
150 }
151
152 pub fn get_configuration(
158 &self,
159 stream_id: &StreamEndpointId,
160 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
161 let stream_params = &[stream_id.to_msg()];
162 self.send_command::<GetCapabilitiesResponse>(
163 SignalIdentifier::GetConfiguration,
164 stream_params,
165 )
166 .ok_into()
167 }
168
169 pub fn reconfigure(
178 &self,
179 stream_id: &StreamEndpointId,
180 capabilities: &[ServiceCapability],
181 ) -> impl Future<Output = Result<()>> {
182 assert!(!capabilities.is_empty(), "must set at least one capability");
183 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
184 params[0] = stream_id.to_msg();
185 let mut idx = 1;
186 for capability in capabilities {
187 if !capability.is_application() {
188 return futures::future::err(Error::Encoding).left_future();
189 }
190 if let Err(e) = capability.encode(&mut params[idx..]) {
191 return futures::future::err(e).left_future();
192 }
193 idx += capability.encoded_len();
194 }
195 self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, ¶ms)
196 .ok_into()
197 .right_future()
198 }
199
200 pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
204 let stream_params = &[stream_id.to_msg()];
205 self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
206 }
207
208 pub fn start(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
213 let mut stream_params = Vec::with_capacity(stream_ids.len());
214 for stream_id in stream_ids {
215 stream_params.push(stream_id.to_msg());
216 }
217 self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
218 }
219
220 pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
223 let stream_params = &[stream_id.to_msg()];
224 let response: CommandResponseFut<SimpleResponse> =
225 self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
226 response.ok_into()
227 }
228
229 pub fn suspend(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
233 let mut stream_params = Vec::with_capacity(stream_ids.len());
234 for stream_id in stream_ids {
235 stream_params.push(stream_id.to_msg());
236 }
237 let response: CommandResponseFut<SimpleResponse> =
238 self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
239 response.ok_into()
240 }
241
242 pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
247 let stream_params = &[stream_id.to_msg()];
248 self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
249 }
250
251 pub fn delay_report(
255 &self,
256 stream_id: &StreamEndpointId,
257 delay: u16,
258 ) -> impl Future<Output = Result<()>> {
259 let delay_bytes: [u8; 2] = delay.to_be_bytes();
260 let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
261 self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
262 }
263
264 const RTX_SIG_TIMER_MS: i64 = 3000;
266 const COMMAND_TIMEOUT: MonotonicDuration =
267 MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
268
269 fn send_command<D: Decodable<Error = Error>>(
272 &self,
273 signal: SignalIdentifier,
274 payload: &[u8],
275 ) -> CommandResponseFut<D> {
276 let send_result = (|| {
277 let id = self.inner.add_response_waiter()?;
278 let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
279 let mut buf = vec![0; header.encoded_len()];
280 header.encode(buf.as_mut_slice())?;
281 buf.extend_from_slice(payload);
282 self.inner.send_signal(buf.as_slice())?;
283 Ok(header)
284 })();
285
286 CommandResponseFut::new(send_result, self.inner.clone())
287 }
288}
289
/// Future returned by `Peer::send_command`: resolves to the response decoded as `D`.
struct CommandResponseFut<D: Decodable> {
    /// Signal of the command that was sent; the response header must carry the same signal.
    id: SignalIdentifier,
    /// The raw response (or the send error), bounded by a timeout that yields an error result.
    fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
    /// Marker for the response type `D`; no `D` value is stored.
    _phantom: PhantomData<D>,
}
296
// The inner future is boxed and pinned on its own, so this wrapper may be moved freely.
impl<D: Decodable> Unpin for CommandResponseFut<D> {}
298
299impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
300 fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
301 let header = match send_result {
302 Err(e) => {
303 return Self {
304 id: SignalIdentifier::Abort,
305 fut: Box::pin(MaybeDone::Done(Err(e))),
306 _phantom: PhantomData,
307 }
308 }
309 Ok(header) => header,
310 };
311 let response = CommandResponse { id: header.label(), inner: Some(inner) };
312 let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
313 let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
314
315 Self {
316 id: header.signal(),
317 fut: Box::pin(futures::future::maybe_done(timedout_fut)),
318 _phantom: PhantomData,
319 }
320 }
321}
322
323impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
324 type Output = Result<D>;
325
326 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
327 ready!(self.fut.poll_unpin(cx));
328 Poll::Ready(
329 self.fut
330 .as_mut()
331 .take_output()
332 .unwrap_or(Err(Error::AlreadyReceived))
333 .and_then(|buf| decode_signaling_response(self.id, buf)),
334 )
335 }
336}
337
/// A request received from the remote peer, produced by `Request::parse`.
///
/// Every variant carries a responder that is used to accept or reject the request.
#[derive(Debug)]
pub enum Request {
    /// Parsed from a `Discover` command; carries no parameters.
    Discover {
        responder: DiscoverResponder,
    },
    /// Parsed from a `GetCapabilities` command with a single stream id.
    GetCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Parsed from a `GetAllCapabilities` command with a single stream id.
    GetAllCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Parsed from a `SetConfiguration` command: first byte is `local_stream_id`, second is
    /// `remote_stream_id`, the rest are the capabilities.
    SetConfiguration {
        local_stream_id: StreamEndpointId,
        remote_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Parsed from a `GetConfiguration` command with a single stream id.
    GetConfiguration {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Parsed from a `Reconfigure` command; every capability satisfied `is_application()`.
    Reconfigure {
        local_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Parsed from an `Open` command with a single stream id.
    Open {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Parsed from a `Start` command; one stream id per body byte (at least one).
    Start {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Parsed from a `Close` command with a single stream id.
    Close {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Parsed from a `Suspend` command; one stream id per body byte (at least one).
    Suspend {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Parsed from an `Abort` command with a single stream id.
    Abort {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Parsed from a `DelayReport` command: a stream id followed by a big-endian `u16` delay.
    DelayReport {
        stream_id: StreamEndpointId,
        delay: u16,
        responder: SimpleResponder,
    },
}
396
/// Expands to an expression that parses a request whose body is exactly one stream id byte.
///
/// Yields `Err(Error::RequestInvalid(ErrorCode::BadLength))` unless `$body` has length 1.
/// Otherwise yields `Ok(Request::$request_variant { .. })` with the stream id read from the
/// byte and a `$responder_type` built from `$signal`, `$peer` and `$id`.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        if $body.len() != 1 {
            Err(Error::RequestInvalid(ErrorCode::BadLength))
        } else {
            Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(&$body[0]),
                responder: $responder_type { signal: $signal, peer: $peer, id: $id },
            })
        }
    };
}
409
410impl Request {
411 fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
412 if body.len() < 1 {
413 return Err(Error::RequestInvalid(ErrorCode::BadLength));
414 }
415 Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
416 }
417
418 fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
419 if encoded.len() < 2 {
420 return Err(Error::RequestInvalid(ErrorCode::BadLength));
421 }
422 let mut caps = vec![];
423 let mut loc = 0;
424 while loc < encoded.len() {
425 let cap = match ServiceCapability::decode(&encoded[loc..]) {
426 Ok(cap) => cap,
427 Err(Error::RequestInvalid(code)) => {
428 return Err(Error::RequestInvalidExtra(code, encoded[loc]));
429 }
430 Err(e) => return Err(e),
431 };
432 loc += cap.encoded_len();
433 caps.push(cap);
434 }
435 Ok(caps)
436 }
437
438 fn parse(
439 peer: Arc<PeerInner>,
440 id: TxLabel,
441 signal: SignalIdentifier,
442 body: &[u8],
443 ) -> Result<Request> {
444 match signal {
445 SignalIdentifier::Discover => {
446 if body.len() > 0 {
448 return Err(Error::RequestInvalid(ErrorCode::BadLength));
449 }
450 Ok(Request::Discover { responder: DiscoverResponder { peer, id } })
451 }
452 SignalIdentifier::GetCapabilities => {
453 parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
454 }
455 SignalIdentifier::GetAllCapabilities => parse_one_seid!(
456 body,
457 signal,
458 peer,
459 id,
460 GetAllCapabilities,
461 GetCapabilitiesResponder
462 ),
463 SignalIdentifier::SetConfiguration => {
464 if body.len() < 4 {
465 return Err(Error::RequestInvalid(ErrorCode::BadLength));
466 }
467 let requested = Request::get_req_capabilities(&body[2..])?;
468 Ok(Request::SetConfiguration {
469 local_stream_id: StreamEndpointId::from_msg(&body[0]),
470 remote_stream_id: StreamEndpointId::from_msg(&body[1]),
471 capabilities: requested,
472 responder: ConfigureResponder { signal, peer, id },
473 })
474 }
475 SignalIdentifier::GetConfiguration => {
476 parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
477 }
478 SignalIdentifier::Reconfigure => {
479 if body.len() < 3 {
480 return Err(Error::RequestInvalid(ErrorCode::BadLength));
481 }
482 let requested = Request::get_req_capabilities(&body[1..])?;
483 match requested.iter().find(|x| !x.is_application()) {
484 Some(x) => {
485 return Err(Error::RequestInvalidExtra(
486 ErrorCode::InvalidCapabilities,
487 (&x.category()).into(),
488 ));
489 }
490 None => (),
491 };
492 Ok(Request::Reconfigure {
493 local_stream_id: StreamEndpointId::from_msg(&body[0]),
494 capabilities: requested,
495 responder: ConfigureResponder { signal, peer, id },
496 })
497 }
498 SignalIdentifier::Open => {
499 parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
500 }
501 SignalIdentifier::Start => {
502 let seids = Request::get_req_seids(body)?;
503 Ok(Request::Start {
504 stream_ids: seids,
505 responder: StreamResponder { signal, peer, id },
506 })
507 }
508 SignalIdentifier::Close => {
509 parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
510 }
511 SignalIdentifier::Suspend => {
512 let seids = Request::get_req_seids(body)?;
513 Ok(Request::Suspend {
514 stream_ids: seids,
515 responder: StreamResponder { signal, peer, id },
516 })
517 }
518 SignalIdentifier::Abort => {
519 parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
520 }
521 SignalIdentifier::DelayReport => {
522 if body.len() != 3 {
523 return Err(Error::RequestInvalid(ErrorCode::BadLength));
524 }
525 let delay_arr: [u8; 2] = [body[1], body[2]];
526 let delay = u16::from_be_bytes(delay_arr);
527 Ok(Request::DelayReport {
528 stream_id: StreamEndpointId::from_msg(&body[0]),
529 delay,
530 responder: SimpleResponder { signal, peer, id },
531 })
532 }
533 _ => Err(Error::UnimplementedMessage),
534 }
535 }
536}
537
/// Stream of requests from the remote peer, obtained from `Peer::take_request_stream`.
///
/// Dropping the stream clears the request listener on the shared peer state.
#[derive(Debug)]
pub struct RequestStream {
    /// Shared peer state that requests are read from and rejections are sent through.
    inner: Arc<PeerInner>,
}
543
// The stream only holds an `Arc`, so it can be moved freely after being polled.
impl Unpin for RequestStream {}
545
546impl Stream for RequestStream {
547 type Item = Result<Request>;
548
549 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
550 Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
551 Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
552 match Request::parse(self.inner.clone(), label, signal, &body) {
553 Err(Error::RequestInvalid(code)) => {
554 self.inner.send_reject(label, signal, code)?;
555 return Poll::Pending;
556 }
557 Err(Error::RequestInvalidExtra(code, extra)) => {
558 self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
559 return Poll::Pending;
560 }
561 Err(Error::UnimplementedMessage) => {
562 self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
563 return Poll::Pending;
564 }
565 x => Some(x),
566 }
567 }
568 Err(Error::PeerDisconnected) => None,
569 Err(e) => Some(Err(e)),
570 })
571 }
572}
573
574impl Drop for RequestStream {
575 fn drop(&mut self) {
576 self.inner.incoming_requests.lock().listener = RequestListener::None;
577 self.inner.wake_any();
578 }
579}
580
/// A response that carries no parameters.
#[derive(Debug)]
pub struct SimpleResponse {}
584
585impl Decodable for SimpleResponse {
586 type Error = Error;
587
588 fn decode(from: &[u8]) -> Result<Self> {
589 if from.len() > 0 {
590 return Err(Error::InvalidMessage);
591 }
592 Ok(SimpleResponse {})
593 }
594}
595
596impl Into<()> for SimpleResponse {
597 fn into(self) -> () {
598 ()
599 }
600}
601
/// Decoded parameters of a `Discover` response.
#[derive(Debug)]
struct DiscoverResponse {
    /// Stream information entries, in the order they appeared in the response.
    endpoints: Vec<StreamInformation>,
}
606
607impl Decodable for DiscoverResponse {
608 type Error = Error;
609
610 fn decode(from: &[u8]) -> Result<Self> {
611 let mut endpoints = Vec::<StreamInformation>::new();
612 let mut idx = 0;
613 while idx < from.len() {
614 let endpoint = StreamInformation::decode(&from[idx..])?;
615 idx += endpoint.encoded_len();
616 endpoints.push(endpoint);
617 }
618 Ok(DiscoverResponse { endpoints })
619 }
620}
621
622impl Into<Vec<StreamInformation>> for DiscoverResponse {
623 fn into(self) -> Vec<StreamInformation> {
624 self.endpoints
625 }
626}
627
/// Responder for a `Request::Discover`; consumed by either `send` or `reject`.
#[derive(Debug)]
pub struct DiscoverResponder {
    /// Connection the reply is sent on.
    peer: Arc<PeerInner>,
    /// Label of the request being answered.
    id: TxLabel,
}
633
634impl DiscoverResponder {
635 pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
639 if endpoints.len() == 0 {
640 return Err(Error::Encoding);
642 }
643 let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
644 let mut idx = 0;
645 for endpoint in endpoints {
646 endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
647 idx += endpoint.encoded_len();
648 }
649 self.peer.send_response(self.id, SignalIdentifier::Discover, ¶ms)
650 }
651
652 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
653 self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
654 }
655}
656
/// Responder for requests answered with a list of capabilities (`GetCapabilities`,
/// `GetAllCapabilities` and `GetConfiguration`).
#[derive(Debug)]
pub struct GetCapabilitiesResponder {
    /// Connection the reply is sent on.
    peer: Arc<PeerInner>,
    /// Signal of the request; echoed in the reply and used to filter capabilities.
    signal: SignalIdentifier,
    /// Label of the request being answered.
    id: TxLabel,
}
663
664impl GetCapabilitiesResponder {
665 pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
666 let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
667 let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
668 let mut reply = vec![0 as u8; reply_len];
669 let mut pos = 0;
670 for capability in included_iter {
671 let size = capability.encoded_len();
672 capability.encode(&mut reply[pos..pos + size])?;
673 pos += size;
674 }
675 self.peer.send_response(self.id, self.signal, &reply)
676 }
677
678 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
679 self.peer.send_reject(self.id, self.signal, error_code)
680 }
681}
682
/// Decoded parameters of a response that lists capabilities.
#[derive(Debug)]
struct GetCapabilitiesResponse {
    /// The capabilities that decoded successfully, in response order.
    capabilities: Vec<ServiceCapability>,
}
687
688impl Decodable for GetCapabilitiesResponse {
689 type Error = Error;
690
691 fn decode(from: &[u8]) -> Result<Self> {
692 let mut capabilities = Vec::<ServiceCapability>::new();
693 let mut idx = 0;
694 while idx < from.len() {
695 match ServiceCapability::decode(&from[idx..]) {
696 Ok(capability) => {
697 idx = idx + capability.encoded_len();
698 capabilities.push(capability);
699 }
700 Err(_) => {
701 info!(
706 "GetCapabilitiesResponse decode: Capability {:?} not supported.",
707 from[idx]
708 );
709 let length_of_capability = from[idx + 1] as usize;
710 idx = idx + 2 + length_of_capability;
711 }
712 }
713 }
714 Ok(GetCapabilitiesResponse { capabilities })
715 }
716}
717
718impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
719 fn into(self) -> Vec<ServiceCapability> {
720 self.capabilities
721 }
722}
723
/// Responder for requests that are accepted with an empty reply and rejected with a single
/// error code.
#[derive(Debug)]
pub struct SimpleResponder {
    /// Connection the reply is sent on.
    peer: Arc<PeerInner>,
    /// Signal of the request; echoed in the reply.
    signal: SignalIdentifier,
    /// Label of the request being answered.
    id: TxLabel,
}
730
731impl SimpleResponder {
732 pub fn send(self) -> Result<()> {
733 self.peer.send_response(self.id, self.signal, &[])
734 }
735
736 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
737 self.peer.send_reject(self.id, self.signal, error_code)
738 }
739}
740
/// Responder for requests that name several streams (`Start` and `Suspend`); a rejection
/// identifies the stream it applies to.
#[derive(Debug)]
pub struct StreamResponder {
    /// Connection the reply is sent on.
    peer: Arc<PeerInner>,
    /// Signal of the request; echoed in the reply.
    signal: SignalIdentifier,
    /// Label of the request being answered.
    id: TxLabel,
}
747
748impl StreamResponder {
749 pub fn send(self) -> Result<()> {
750 self.peer.send_response(self.id, self.signal, &[])
751 }
752
753 pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
754 self.peer.send_reject_params(
755 self.id,
756 self.signal,
757 &[stream_id.to_msg(), u8::from(&error_code)],
758 )
759 }
760}
761
/// Responder for configuration requests (`SetConfiguration` and `Reconfigure`); a rejection
/// identifies the service category it applies to.
#[derive(Debug)]
pub struct ConfigureResponder {
    /// Connection the reply is sent on.
    peer: Arc<PeerInner>,
    /// Signal of the request; echoed in the reply.
    signal: SignalIdentifier,
    /// Label of the request being answered.
    id: TxLabel,
}
768
769impl ConfigureResponder {
770 pub fn send(self) -> Result<()> {
771 self.peer.send_response(self.id, self.signal, &[])
772 }
773
774 pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
775 self.peer.send_reject_params(
776 self.id,
777 self.signal,
778 &[u8::from(&category), u8::from(&error_code)],
779 )
780 }
781}
782
/// A received command that has not been parsed yet: its header and the bytes after it.
#[derive(Debug)]
struct UnparsedRequest(SignalingHeader, Vec<u8>);
785
786impl UnparsedRequest {
787 fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
788 UnparsedRequest(header, body)
789 }
790}
791
/// Commands received from the peer that the request stream has not consumed yet.
#[derive(Debug, Default)]
struct RequestQueue {
    /// Who, if anyone, is listening for requests.
    listener: RequestListener,
    /// Received commands, oldest first.
    queue: VecDeque<UnparsedRequest>,
}
797
/// State of the single request listener.
#[derive(Debug)]
enum RequestListener {
    /// No request stream exists; `Peer::take_request_stream` may hand one out.
    None,
    /// A request stream has been taken but has not stored a waker yet.
    New,
    /// The request stream polled and stored this waker, to be woken when a command arrives.
    Some(Waker),
}
807
808impl Default for RequestListener {
809 fn default() -> Self {
810 RequestListener::None
811 }
812}
813
/// State of one outstanding command, stored in the slab under the command's label.
#[derive(Debug)]
enum ResponseWaiter {
    /// The command was registered but its future has not stored a waker yet.
    WillPoll,
    /// The future polled without a response being available and stored this waker.
    Waiting(Waker),
    /// The response arrived; holds the complete packet, header included.
    Received(Vec<u8>),
    /// The future was dropped before a response arrived; the entry is removed once the
    /// response shows up.
    Discard,
}
828
829impl ResponseWaiter {
830 fn is_received(&self) -> bool {
832 if let ResponseWaiter::Received(_) = self {
833 true
834 } else {
835 false
836 }
837 }
838
839 fn unwrap_received(self) -> Vec<u8> {
840 if let ResponseWaiter::Received(buf) = self {
841 buf
842 } else {
843 panic!("expected received buf")
844 }
845 }
846}
847
848fn decode_signaling_response<D: Decodable<Error = Error>>(
849 expected_signal: SignalIdentifier,
850 buf: Vec<u8>,
851) -> Result<D> {
852 let header = SignalingHeader::decode(buf.as_slice())?;
853 if header.signal() != expected_signal {
854 return Err(Error::InvalidHeader);
855 }
856 let params = &buf[header.encoded_len()..];
857 match header.message_type {
858 SignalingMessageType::ResponseAccept => D::decode(params),
859 SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
860 Err(RemoteReject::from_params(header.signal(), params).into())
861 }
862 SignalingMessageType::Command => unreachable!(),
863 }
864}
865
/// Future for the raw response packet to one sent command.
#[derive(Debug)]
pub struct CommandResponse {
    /// Label of the command whose response is awaited.
    id: TxLabel,
    /// Shared peer state; set to `None` once the response has been delivered.
    inner: Option<Arc<PeerInner>>,
}
873
// Holds only a label and an optional `Arc`, so it can be moved freely after being polled.
impl Unpin for CommandResponse {}
875
876impl Future for CommandResponse {
877 type Output = Result<Vec<u8>>;
878 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
879 let this = &mut *self;
880 let res;
881 {
882 let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
883 res = client.poll_recv_response(&this.id, cx);
884 }
885
886 if let Poll::Ready(Ok(_)) = res {
887 let inner = this.inner.take().expect("CommandResponse polled after completion");
888 inner.wake_any();
889 }
890
891 res
892 }
893}
894
impl FusedFuture for CommandResponse {
    /// The future is finished once the link to the peer state has been taken, which `poll`
    /// does when it delivers a response.
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}
900
901impl Drop for CommandResponse {
902 fn drop(&mut self) {
903 if let Some(inner) = &self.inner {
904 inner.remove_response_interest(&self.id);
905 inner.wake_any();
906 }
907 }
908}
909
/// State shared by a `Peer`, its request stream, its responders and its pending command
/// futures.
#[derive(Debug)]
struct PeerInner {
    /// Channel that signaling messages are read from and written to.
    signaling: Channel,

    /// One entry per outstanding command; the slab key is the command's label.
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// Commands received from the peer, plus the state of the request listener.
    incoming_requests: Mutex<RequestQueue>,
}
927
impl PeerInner {
    /// Registers a waiter for a new command and returns the label to send it with.
    ///
    /// The label is derived from the slab key. If the key cannot be converted to a label,
    /// the entry is removed again and the conversion error is returned.
    fn add_response_waiter(&self) -> Result<TxLabel> {
        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
        // NOTE(review): `as u8` truncates keys above 255; presumably unreachable because a
        // key that is not a valid label is removed right away - confirm.
        let id = TxLabel::try_from(key as u8);
        if id.is_err() {
            warn!("Transaction IDs are exhausted");
            let _ = self.response_waiters.lock().remove(key);
        }
        id
    }

    /// Withdraws interest in the response for `id`.
    ///
    /// If the response has already arrived the entry is removed. Otherwise it is marked
    /// `Discard`, so that `recv_all` removes it when the response shows up. The slab is
    /// indexed directly, so `id` is expected to be registered.
    fn remove_response_interest(&self, id: &TxLabel) {
        let mut lock = self.response_waiters.lock();
        let idx = usize::from(id);
        if lock[idx].is_received() {
            let _ = lock.remove(idx);
        } else {
            lock[idx] = ResponseWaiter::Discard;
        }
    }

    /// Polls for the next command from the peer.
    ///
    /// Drains the channel first. If no command is queued, the task's waker is stored as the
    /// request listener and the result is `Pending`, or `Error::PeerDisconnected` when the
    /// channel is closed.
    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
        let is_closed = self.recv_all(cx)?;

        let mut lock = self.incoming_requests.lock();

        if let Some(request) = lock.queue.pop_front() {
            Poll::Ready(Ok(request))
        } else {
            lock.listener = RequestListener::Some(cx.waker().clone());
            if is_closed {
                Poll::Ready(Err(Error::PeerDisconnected))
            } else {
                Poll::Pending
            }
        }
    }

    /// Polls for the response to the command sent with `label`.
    ///
    /// Drains the channel first. If the response has arrived, its entry is removed and the
    /// packet returned. Otherwise the task's waker is stored in the entry and the result is
    /// `Pending`, or `Error::PeerDisconnected` when the channel is closed.
    ///
    /// # Panics
    ///
    /// Panics if no waiter is registered under `label`.
    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
        let is_closed = self.recv_all(cx)?;

        let mut waiters = self.response_waiters.lock();
        let idx = usize::from(label);
        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
            let buf = waiters.remove(idx).unwrap_received();
            Poll::Ready(Ok(buf))
        } else {
            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
                ResponseWaiter::Waiting(cx.waker().clone());

            if is_closed {
                Poll::Ready(Err(Error::PeerDisconnected))
            } else {
                Poll::Pending
            }
        }
    }

    /// Reads every packet currently available on the channel and routes it.
    ///
    /// Commands are queued for the request listener, which is woken if it stored a waker.
    /// Responses are stored in the waiter registered under their label, waking it if it was
    /// waiting. Returns `Ok(true)` when the peer has closed the channel, `Ok(false)` when no
    /// more data is available right now, and `Error::PeerRead` on any other read error.
    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
        loop {
            let mut next_packet = Vec::new();
            let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
                Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
                    trace!("Signaling peer closed");
                    return Ok(true);
                }
                Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
                Poll::Pending => return Ok(false),
                Poll::Ready(Ok(size)) => size,
            };
            // Nothing to route in an empty packet.
            if packet_size == 0 {
                continue;
            }
            let header = match SignalingHeader::decode(next_packet.as_slice()) {
                // The header's label could be read but its signal id is invalid: reply with
                // a general reject.
                Err(Error::InvalidSignalId(label, id)) => {
                    self.send_general_reject(label, id)?;
                    continue;
                }
                // Any other header error leaves nothing to reply to; drop the packet.
                Err(_) => {
                    info!("received unrejectable message");
                    continue;
                }
                Ok(x) => x,
            };
            if header.is_command() {
                let mut lock = self.incoming_requests.lock();
                // Keep only the bytes after the header as the request body.
                let body = next_packet.split_off(header.encoded_len());
                lock.queue.push_back(UnparsedRequest::new(header, body));
                if let RequestListener::Some(ref waker) = lock.listener {
                    waker.wake_by_ref();
                }
            } else {
                let mut waiters = self.response_waiters.lock();
                let idx = usize::from(&header.label());
                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
                    // Nobody is interested in this response any more; free the entry.
                    let _ = waiters.remove(idx);
                } else if let Some(entry) = waiters.get_mut(idx) {
                    // Store the whole packet, header included, and wake the waiter if any.
                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
                    if let ResponseWaiter::Waiting(waker) = old_entry {
                        waker.wake();
                    }
                } else {
                    warn!("response for {:?} we did not send, dropping", header.label());
                }
            }
        }
    }

    /// Wakes at most one stored waker: the first waiting response waiter found, or else the
    /// request listener.
    fn wake_any(&self) {
        {
            let lock = self.response_waiters.lock();
            for (_, response_waiter) in lock.iter() {
                if let ResponseWaiter::Waiting(waker) = response_waiter {
                    waker.wake_by_ref();
                    return;
                }
            }
        }
        {
            let lock = self.incoming_requests.lock();
            if let RequestListener::Some(waker) = &lock.listener {
                waker.wake_by_ref();
                return;
            }
        }
    }

    /// Sends a two-byte general reject.
    ///
    /// The first byte is `label` shifted into the high four bits, ORed with `0x01`; the
    /// second is `invalid_signal_id` masked to its low six bits.
    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
        self.send_signal(packet)
    }

    /// Sends a `ResponseAccept` message for `signal` with `label`, followed by `params`.
    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
        header.encode(packet.as_mut_slice())?;
        packet[header.encoded_len()..].clone_from_slice(params);
        self.send_signal(&packet)
    }

    /// Sends a `ResponseReject` message whose only parameter is `error_code`.
    fn send_reject(
        &self,
        label: TxLabel,
        signal: SignalIdentifier,
        error_code: ErrorCode,
    ) -> Result<()> {
        self.send_reject_params(label, signal, &[u8::from(&error_code)])
    }

    /// Sends a `ResponseReject` message for `signal` with `label`, followed by `params`.
    fn send_reject_params(
        &self,
        label: TxLabel,
        signal: SignalIdentifier,
        params: &[u8],
    ) -> Result<()> {
        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
        header.encode(packet.as_mut_slice())?;
        packet[header.encoded_len()..].clone_from_slice(params);
        self.send_signal(&packet)
    }

    /// Writes `data` to the signaling channel, mapping a write failure to `Error::PeerWrite`.
    fn send_signal(&self, data: &[u8]) -> Result<()> {
        let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
        Ok(())
    }
}