use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
use fuchsia_bluetooth::types::Channel;
use fuchsia_sync::Mutex;
use futures::future::{FusedFuture, MaybeDone};
use futures::stream::Stream;
use futures::task::{Context, Poll, Waker};
use futures::{ready, Future, FutureExt, TryFutureExt};
use log::{info, trace, warn};
use packet_encoding::{Decodable, Encodable};
use slab::Slab;
use std::collections::VecDeque;
use std::marker::PhantomData;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use zx::{self as zx, MonotonicDuration};
#[cfg(test)]
mod tests;
mod rtp;
mod stream_endpoint;
mod types;
use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
pub use crate::rtp::{RtpError, RtpHeader};
pub use crate::stream_endpoint::{
MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
};
pub use crate::types::{
ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
};
/// An AVDTP signaling peer.
///
/// Commands to the remote peer are issued through the methods on this type; commands sent by the
/// remote peer are delivered through the [`RequestStream`] returned by
/// [`Peer::take_request_stream`]. Clones of a `Peer` share the same signaling channel and state.
#[derive(Debug, Clone)]
pub struct Peer {
    inner: Arc<PeerInner>,
}

impl Peer {
    /// Creates a peer that communicates over the `signaling` channel.
    pub fn new(signaling: Channel) -> Self {
        Self {
            inner: Arc::new(PeerInner {
                signaling,
                response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
                incoming_requests: Mutex::<RequestQueue>::default(),
            }),
        }
    }

    /// Takes the stream of requests sent by the remote peer.
    ///
    /// # Panics
    ///
    /// Panics if a request stream has already been taken and has not been dropped yet.
    #[track_caller]
    pub fn take_request_stream(&self) -> RequestStream {
        {
            let mut lock = self.inner.incoming_requests.lock();
            if let RequestListener::None = lock.listener {
                lock.listener = RequestListener::New;
            } else {
                panic!("Request stream has already been taken");
            }
        }
        RequestStream { inner: self.inner.clone() }
    }

    /// Sends a Discover command; resolves to the stream endpoints reported by the remote peer.
    pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> {
        self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
    }

    /// Sends a Get Capabilities command for `stream_id`; resolves to the decoded capabilities.
    pub fn get_capabilities(
        &self,
        stream_id: &StreamEndpointId,
    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
        let stream_params = &[stream_id.to_msg()];
        self.send_command::<GetCapabilitiesResponse>(
            SignalIdentifier::GetCapabilities,
            stream_params,
        )
        .ok_into()
    }

    /// Sends a Get All Capabilities command for `stream_id`; resolves to the decoded
    /// capabilities.
    pub fn get_all_capabilities(
        &self,
        stream_id: &StreamEndpointId,
    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
        let stream_params = &[stream_id.to_msg()];
        self.send_command::<GetCapabilitiesResponse>(
            SignalIdentifier::GetAllCapabilities,
            stream_params,
        )
        .ok_into()
    }

    /// Sends a Set Configuration command.
    ///
    /// The parameters are `stream_id`, then `local_stream_id`, then each capability encoded in
    /// order.
    ///
    /// # Panics
    ///
    /// Panics if `capabilities` is empty.
    pub fn set_configuration(
        &self,
        stream_id: &StreamEndpointId,
        local_stream_id: &StreamEndpointId,
        capabilities: &[ServiceCapability],
    ) -> impl Future<Output = Result<()>> {
        assert!(!capabilities.is_empty(), "must set at least one capability");
        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
        params[0] = stream_id.to_msg();
        params[1] = local_stream_id.to_msg();
        let mut idx = 2;
        for capability in capabilities {
            if let Err(e) = capability.encode(&mut params[idx..]) {
                return futures::future::err(e).left_future();
            }
            idx += capability.encoded_len();
        }
        self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, &params)
            .ok_into()
            .right_future()
    }

    /// Sends a Get Configuration command for `stream_id`; resolves to the configured
    /// capabilities.
    pub fn get_configuration(
        &self,
        stream_id: &StreamEndpointId,
    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
        let stream_params = &[stream_id.to_msg()];
        self.send_command::<GetCapabilitiesResponse>(
            SignalIdentifier::GetConfiguration,
            stream_params,
        )
        .ok_into()
    }

    /// Sends a Reconfigure command for `stream_id` with `capabilities`.
    ///
    /// Resolves to `Error::Encoding` without sending anything if any capability is not an
    /// application capability.
    ///
    /// # Panics
    ///
    /// Panics if `capabilities` is empty.
    pub fn reconfigure(
        &self,
        stream_id: &StreamEndpointId,
        capabilities: &[ServiceCapability],
    ) -> impl Future<Output = Result<()>> {
        assert!(!capabilities.is_empty(), "must set at least one capability");
        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
        params[0] = stream_id.to_msg();
        let mut idx = 1;
        for capability in capabilities {
            if !capability.is_application() {
                return futures::future::err(Error::Encoding).left_future();
            }
            if let Err(e) = capability.encode(&mut params[idx..]) {
                return futures::future::err(e).left_future();
            }
            idx += capability.encoded_len();
        }
        self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, &params)
            .ok_into()
            .right_future()
    }

    /// Sends an Open command for `stream_id`.
    pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
        let stream_params = &[stream_id.to_msg()];
        self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
    }

    /// Sends a Start command for all of `stream_ids`.
    pub fn start(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
        let stream_params: Vec<u8> = stream_ids.iter().map(|id| id.to_msg()).collect();
        self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
    }

    /// Sends a Close command for `stream_id`.
    pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
        let stream_params = &[stream_id.to_msg()];
        self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params).ok_into()
    }

    /// Sends a Suspend command for all of `stream_ids`.
    pub fn suspend(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
        let stream_params: Vec<u8> = stream_ids.iter().map(|id| id.to_msg()).collect();
        self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params).ok_into()
    }

    /// Sends an Abort command for `stream_id`.
    pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
        let stream_params = &[stream_id.to_msg()];
        self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
    }

    /// Sends a Delay Report command for `stream_id`.
    ///
    /// `delay` is encoded big-endian after the stream id.
    pub fn delay_report(
        &self,
        stream_id: &StreamEndpointId,
        delay: u16,
    ) -> impl Future<Output = Result<()>> {
        let delay_bytes: [u8; 2] = delay.to_be_bytes();
        let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
        self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
    }

    /// Signaling response timeout, in milliseconds.
    const RTX_SIG_TIMER_MS: i64 = 3000;
    /// How long a command waits for its response before failing with `Error::Timeout`.
    const COMMAND_TIMEOUT: MonotonicDuration =
        MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);

    /// Sends a `signal` command carrying `payload`.
    ///
    /// Returns a future that resolves to the response decoded as `D`. The future fails with
    /// `Error::Timeout` if no response arrives within `COMMAND_TIMEOUT`. It is already complete
    /// with the error if no transaction label was free or the command could not be sent.
    fn send_command<D: Decodable<Error = Error>>(
        &self,
        signal: SignalIdentifier,
        payload: &[u8],
    ) -> CommandResponseFut<D> {
        let send_result = (|| {
            let id = self.inner.add_response_waiter()?;
            let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
            let sent: Result<()> = (|| {
                let mut buf = vec![0; header.encoded_len()];
                header.encode(buf.as_mut_slice())?;
                buf.extend_from_slice(payload);
                self.inner.send_signal(buf.as_slice())
            })();
            if let Err(e) = sent {
                // No CommandResponse will ever own this label, so nothing else would release
                // the waiter slot; free it now so the transaction label can be reused.
                let mut waiters = self.inner.response_waiters.lock();
                let idx = usize::from(&id);
                if waiters.contains(idx) {
                    let _ = waiters.remove(idx);
                }
                return Err(e);
            }
            Ok(header)
        })();
        CommandResponseFut::new(send_result, self.inner.clone())
    }
}
struct CommandResponseFut<D: Decodable> {
id: SignalIdentifier,
fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
_phantom: PhantomData<D>,
}
impl<D: Decodable> Unpin for CommandResponseFut<D> {}
impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
let header = match send_result {
Err(e) => {
return Self {
id: SignalIdentifier::Abort,
fut: Box::pin(MaybeDone::Done(Err(e))),
_phantom: PhantomData,
}
}
Ok(header) => header,
};
let response = CommandResponse { id: header.label(), inner: Some(inner) };
let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
Self {
id: header.signal(),
fut: Box::pin(futures::future::maybe_done(timedout_fut)),
_phantom: PhantomData,
}
}
}
impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
type Output = Result<D>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
ready!(self.fut.poll_unpin(cx));
Poll::Ready(
self.fut
.as_mut()
.take_output()
.unwrap_or(Err(Error::AlreadyReceived))
.and_then(|buf| decode_signaling_response(self.id, buf)),
)
}
}
/// A command received from the remote peer, paired with the responder used to answer it.
#[derive(Debug)]
pub enum Request {
    /// Discover command; answer with the local stream endpoints.
    Discover { responder: DiscoverResponder },
    /// Get Capabilities command for `stream_id`.
    GetCapabilities { stream_id: StreamEndpointId, responder: GetCapabilitiesResponder },
    /// Get All Capabilities command for `stream_id`.
    GetAllCapabilities { stream_id: StreamEndpointId, responder: GetCapabilitiesResponder },
    /// Set Configuration command: `local_stream_id` is the first SEID in the command and
    /// `remote_stream_id` the second, followed by the requested `capabilities`.
    SetConfiguration {
        local_stream_id: StreamEndpointId,
        remote_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Get Configuration command for `stream_id`.
    GetConfiguration { stream_id: StreamEndpointId, responder: GetCapabilitiesResponder },
    /// Reconfigure command for `local_stream_id` with the requested `capabilities`.
    Reconfigure {
        local_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Open command for `stream_id`.
    Open { stream_id: StreamEndpointId, responder: SimpleResponder },
    /// Start command for one or more `stream_ids`.
    Start { stream_ids: Vec<StreamEndpointId>, responder: StreamResponder },
    /// Close command for `stream_id`.
    Close { stream_id: StreamEndpointId, responder: SimpleResponder },
    /// Suspend command for one or more `stream_ids`.
    Suspend { stream_ids: Vec<StreamEndpointId>, responder: StreamResponder },
    /// Abort command for `stream_id`.
    Abort { stream_id: StreamEndpointId, responder: SimpleResponder },
    /// Delay Report command for `stream_id`, carrying the `delay` value sent by the peer.
    DelayReport { stream_id: StreamEndpointId, delay: u16, responder: SimpleResponder },
}
/// Builds `Request::$request_variant` from a command body that must be exactly one SEID byte,
/// pairing it with a `$responder_type` for `$signal`/`$peer`/`$id`. Any other body length is
/// rejected with `BadLength`.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        match $body {
            [seid] => Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(seid),
                responder: $responder_type { signal: $signal, peer: $peer, id: $id },
            }),
            _ => Err(Error::RequestInvalid(ErrorCode::BadLength)),
        }
    };
}
impl Request {
fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
if body.len() < 1 {
return Err(Error::RequestInvalid(ErrorCode::BadLength));
}
Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
}
fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
if encoded.len() < 2 {
return Err(Error::RequestInvalid(ErrorCode::BadLength));
}
let mut caps = vec![];
let mut loc = 0;
while loc < encoded.len() {
let cap = match ServiceCapability::decode(&encoded[loc..]) {
Ok(cap) => cap,
Err(Error::RequestInvalid(code)) => {
return Err(Error::RequestInvalidExtra(code, encoded[loc]));
}
Err(e) => return Err(e),
};
loc += cap.encoded_len();
caps.push(cap);
}
Ok(caps)
}
fn parse(
peer: Arc<PeerInner>,
id: TxLabel,
signal: SignalIdentifier,
body: &[u8],
) -> Result<Request> {
match signal {
SignalIdentifier::Discover => {
if body.len() > 0 {
return Err(Error::RequestInvalid(ErrorCode::BadLength));
}
Ok(Request::Discover { responder: DiscoverResponder { peer, id } })
}
SignalIdentifier::GetCapabilities => {
parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
}
SignalIdentifier::GetAllCapabilities => parse_one_seid!(
body,
signal,
peer,
id,
GetAllCapabilities,
GetCapabilitiesResponder
),
SignalIdentifier::SetConfiguration => {
if body.len() < 4 {
return Err(Error::RequestInvalid(ErrorCode::BadLength));
}
let requested = Request::get_req_capabilities(&body[2..])?;
Ok(Request::SetConfiguration {
local_stream_id: StreamEndpointId::from_msg(&body[0]),
remote_stream_id: StreamEndpointId::from_msg(&body[1]),
capabilities: requested,
responder: ConfigureResponder { signal, peer, id },
})
}
SignalIdentifier::GetConfiguration => {
parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
}
SignalIdentifier::Reconfigure => {
if body.len() < 3 {
return Err(Error::RequestInvalid(ErrorCode::BadLength));
}
let requested = Request::get_req_capabilities(&body[1..])?;
match requested.iter().find(|x| !x.is_application()) {
Some(x) => {
return Err(Error::RequestInvalidExtra(
ErrorCode::InvalidCapabilities,
(&x.category()).into(),
));
}
None => (),
};
Ok(Request::Reconfigure {
local_stream_id: StreamEndpointId::from_msg(&body[0]),
capabilities: requested,
responder: ConfigureResponder { signal, peer, id },
})
}
SignalIdentifier::Open => {
parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
}
SignalIdentifier::Start => {
let seids = Request::get_req_seids(body)?;
Ok(Request::Start {
stream_ids: seids,
responder: StreamResponder { signal, peer, id },
})
}
SignalIdentifier::Close => {
parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
}
SignalIdentifier::Suspend => {
let seids = Request::get_req_seids(body)?;
Ok(Request::Suspend {
stream_ids: seids,
responder: StreamResponder { signal, peer, id },
})
}
SignalIdentifier::Abort => {
parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
}
SignalIdentifier::DelayReport => {
if body.len() != 3 {
return Err(Error::RequestInvalid(ErrorCode::BadLength));
}
let delay_arr: [u8; 2] = [body[1], body[2]];
let delay = u16::from_be_bytes(delay_arr);
Ok(Request::DelayReport {
stream_id: StreamEndpointId::from_msg(&body[0]),
delay,
responder: SimpleResponder { signal, peer, id },
})
}
_ => Err(Error::UnimplementedMessage),
}
}
}
#[derive(Debug)]
pub struct RequestStream {
inner: Arc<PeerInner>,
}
impl Unpin for RequestStream {}
impl Stream for RequestStream {
type Item = Result<Request>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
match Request::parse(self.inner.clone(), label, signal, &body) {
Err(Error::RequestInvalid(code)) => {
self.inner.send_reject(label, signal, code)?;
return Poll::Pending;
}
Err(Error::RequestInvalidExtra(code, extra)) => {
self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
return Poll::Pending;
}
Err(Error::UnimplementedMessage) => {
self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
return Poll::Pending;
}
x => Some(x),
}
}
Err(Error::PeerDisconnected) => None,
Err(e) => Some(Err(e)),
})
}
}
impl Drop for RequestStream {
    /// Clears the registered request listener so a new stream can be taken.
    ///
    /// Afterwards it calls `wake_any`, presumably so another task resumes reading the shared
    /// signaling channel.
    fn drop(&mut self) {
        let peer = &self.inner;
        peer.incoming_requests.lock().listener = RequestListener::None;
        peer.wake_any();
    }
}
/// Accepted response to a command whose accept carries no parameters.
#[derive(Debug)]
pub struct SimpleResponse {}

impl Decodable for SimpleResponse {
    type Error = Error;

    /// Decodes the accept's parameters; any non-empty parameter list is `InvalidMessage`.
    fn decode(from: &[u8]) -> Result<Self> {
        if !from.is_empty() {
            return Err(Error::InvalidMessage);
        }
        Ok(SimpleResponse {})
    }
}

// `From` instead of a hand-written `Into`: the blanket impl still provides
// `Into<()> for SimpleResponse`, which `ok_into()` relies on.
impl From<SimpleResponse> for () {
    fn from(_response: SimpleResponse) -> Self {}
}
/// Accepted response to a Discover command: the remote peer's stream endpoints.
#[derive(Debug)]
struct DiscoverResponse {
    endpoints: Vec<StreamInformation>,
}

impl Decodable for DiscoverResponse {
    type Error = Error;

    /// Decodes back-to-back `StreamInformation` records until the parameters are consumed.
    fn decode(from: &[u8]) -> Result<Self> {
        let mut endpoints = Vec::<StreamInformation>::new();
        let mut idx = 0;
        while idx < from.len() {
            let endpoint = StreamInformation::decode(&from[idx..])?;
            idx += endpoint.encoded_len();
            endpoints.push(endpoint);
        }
        Ok(DiscoverResponse { endpoints })
    }
}

// `From` instead of a hand-written `Into`; the blanket impl supplies the `Into` used by
// `ok_into()`.
impl From<DiscoverResponse> for Vec<StreamInformation> {
    fn from(response: DiscoverResponse) -> Self {
        response.endpoints
    }
}
/// Responder for a Discover request.
#[derive(Debug)]
pub struct DiscoverResponder {
    peer: Arc<PeerInner>,
    id: TxLabel,
}

impl DiscoverResponder {
    /// Accepts the Discover command, reporting `endpoints` in order.
    ///
    /// Returns `Error::Encoding` without sending if `endpoints` is empty.
    pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
        if endpoints.is_empty() {
            return Err(Error::Encoding);
        }
        // Size the buffer from every endpoint rather than assuming they all match the first.
        let total_len: usize = endpoints.iter().map(|endpoint| endpoint.encoded_len()).sum();
        let mut params = vec![0u8; total_len];
        let mut idx = 0;
        for endpoint in endpoints {
            let len = endpoint.encoded_len();
            endpoint.encode(&mut params[idx..idx + len])?;
            idx += len;
        }
        self.peer.send_response(self.id, SignalIdentifier::Discover, &params)
    }

    /// Rejects the Discover command with `error_code`.
    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
        self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
    }
}
/// Responder for Get Capabilities, Get All Capabilities and Get Configuration requests.
#[derive(Debug)]
pub struct GetCapabilitiesResponder {
    peer: Arc<PeerInner>,
    signal: SignalIdentifier,
    id: TxLabel,
}

impl GetCapabilitiesResponder {
    /// Accepts the request.
    ///
    /// Only the capabilities for which `in_response(signal)` is true are included, encoded
    /// back-to-back.
    pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
        let included: Vec<&ServiceCapability> =
            capabilities.iter().filter(|cap| cap.in_response(self.signal)).collect();
        let mut reply = vec![0u8; included.iter().map(|cap| cap.encoded_len()).sum::<usize>()];
        let mut offset = 0;
        for cap in included {
            let end = offset + cap.encoded_len();
            cap.encode(&mut reply[offset..end])?;
            offset = end;
        }
        self.peer.send_response(self.id, self.signal, &reply)
    }

    /// Rejects the request with `error_code`.
    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
        let Self { peer, signal, id } = self;
        peer.send_reject(id, signal, error_code)
    }
}
#[derive(Debug)]
struct GetCapabilitiesResponse {
capabilities: Vec<ServiceCapability>,
}
impl Decodable for GetCapabilitiesResponse {
type Error = Error;
fn decode(from: &[u8]) -> Result<Self> {
let mut capabilities = Vec::<ServiceCapability>::new();
let mut idx = 0;
while idx < from.len() {
match ServiceCapability::decode(&from[idx..]) {
Ok(capability) => {
idx = idx + capability.encoded_len();
capabilities.push(capability);
}
Err(_) => {
info!(
"GetCapabilitiesResponse decode: Capability {:?} not supported.",
from[idx]
);
let length_of_capability = from[idx + 1] as usize;
idx = idx + 2 + length_of_capability;
}
}
}
Ok(GetCapabilitiesResponse { capabilities })
}
}
impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
fn into(self) -> Vec<ServiceCapability> {
self.capabilities
}
}
/// Responder for requests whose accept carries no parameters (Open, Close, Abort, Delay Report).
#[derive(Debug)]
pub struct SimpleResponder {
    peer: Arc<PeerInner>,
    signal: SignalIdentifier,
    id: TxLabel,
}

impl SimpleResponder {
    /// Accepts the request with an empty parameter list.
    pub fn send(self) -> Result<()> {
        let Self { peer, signal, id } = self;
        peer.send_response(id, signal, &[])
    }

    /// Rejects the request with `error_code`.
    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
        let Self { peer, signal, id } = self;
        peer.send_reject(id, signal, error_code)
    }
}
/// Responder for requests naming several streams (Start, Suspend).
#[derive(Debug)]
pub struct StreamResponder {
    peer: Arc<PeerInner>,
    signal: SignalIdentifier,
    id: TxLabel,
}

impl StreamResponder {
    /// Accepts the request with an empty parameter list.
    pub fn send(self) -> Result<()> {
        let Self { peer, signal, id } = self;
        peer.send_response(id, signal, &[])
    }

    /// Rejects the request, naming the `stream_id` that failed and the `error_code`.
    pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
        let reject_params = [stream_id.to_msg(), u8::from(&error_code)];
        let Self { peer, signal, id } = self;
        peer.send_reject_params(id, signal, &reject_params)
    }
}
/// Responder for Set Configuration and Reconfigure requests.
#[derive(Debug)]
pub struct ConfigureResponder {
    peer: Arc<PeerInner>,
    signal: SignalIdentifier,
    id: TxLabel,
}

impl ConfigureResponder {
    /// Accepts the request with an empty parameter list.
    pub fn send(self) -> Result<()> {
        let Self { peer, signal, id } = self;
        peer.send_response(id, signal, &[])
    }

    /// Rejects the request, naming the offending service `category` and the `error_code`.
    pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
        let reject_params = [u8::from(&category), u8::from(&error_code)];
        let Self { peer, signal, id } = self;
        peer.send_reject_params(id, signal, &reject_params)
    }
}
/// A received command that has not been parsed yet: its signaling header and the bytes that
/// followed the header.
#[derive(Debug)]
struct UnparsedRequest(SignalingHeader, Vec<u8>);

impl UnparsedRequest {
    /// Pairs a decoded `header` with its raw `body`.
    fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
        Self(header, body)
    }
}
/// Commands received from the remote peer, plus the state of the request-stream listener.
#[derive(Debug, Default)]
struct RequestQueue {
    listener: RequestListener,
    queue: VecDeque<UnparsedRequest>,
}

/// State of the request-stream listener.
#[derive(Debug)]
enum RequestListener {
    /// No request stream has been taken (or the last one was dropped).
    None,
    /// A request stream was taken but has not been polled yet.
    New,
    /// A request stream is waiting; wake this to deliver new requests.
    Some(Waker),
}

impl Default for RequestListener {
    fn default() -> Self {
        Self::None
    }
}
/// State of one outstanding command, stored at the slab index equal to its transaction label.
#[derive(Debug)]
enum ResponseWaiter {
    /// Command sent, response future not polled yet.
    WillPoll,
    /// Response future polled; wake this when the response arrives.
    Waiting(Waker),
    /// Response arrived (whole packet, header included) and awaits collection.
    Received(Vec<u8>),
    /// The response future was dropped; a late response is discarded and the slot freed.
    Discard,
}

impl ResponseWaiter {
    /// Returns true if a response packet has been stored.
    fn is_received(&self) -> bool {
        matches!(self, ResponseWaiter::Received(_))
    }

    /// Extracts the stored response packet.
    ///
    /// # Panics
    ///
    /// Panics if no response has been received.
    fn unwrap_received(self) -> Vec<u8> {
        match self {
            ResponseWaiter::Received(buf) => buf,
            _ => panic!("expected received buf"),
        }
    }
}
/// Decodes a response packet `buf` (header included) for a command sent with `expected_signal`.
///
/// An accept is decoded as `D`; a general or response reject becomes a `RemoteReject` error.
/// A header whose signal differs from `expected_signal` is `InvalidHeader`.
fn decode_signaling_response<D: Decodable<Error = Error>>(
    expected_signal: SignalIdentifier,
    buf: Vec<u8>,
) -> Result<D> {
    let header = SignalingHeader::decode(buf.as_slice())?;
    let received_signal = header.signal();
    if received_signal != expected_signal {
        return Err(Error::InvalidHeader);
    }
    let body = &buf[header.encoded_len()..];
    match header.message_type {
        SignalingMessageType::ResponseAccept => D::decode(body),
        SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
            Err(RemoteReject::from_params(received_signal, body).into())
        }
        // Commands are routed to the request queue when received and never stored as
        // responses, so they cannot reach this function.
        SignalingMessageType::Command => unreachable!(),
    }
}
#[derive(Debug)]
pub struct CommandResponse {
id: TxLabel,
inner: Option<Arc<PeerInner>>,
}
impl Unpin for CommandResponse {}
impl Future for CommandResponse {
type Output = Result<Vec<u8>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
let res;
{
let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
res = client.poll_recv_response(&this.id, cx);
}
if let Poll::Ready(Ok(_)) = res {
let inner = this.inner.take().expect("CommandResponse polled after completion");
inner.wake_any();
}
res
}
}
impl FusedFuture for CommandResponse {
fn is_terminated(&self) -> bool {
self.inner.is_none()
}
}
impl Drop for CommandResponse {
fn drop(&mut self) {
if let Some(inner) = &self.inner {
inner.remove_response_interest(&self.id);
inner.wake_any();
}
}
}
#[derive(Debug)]
struct PeerInner {
signaling: Channel,
response_waiters: Mutex<Slab<ResponseWaiter>>,
incoming_requests: Mutex<RequestQueue>,
}
impl PeerInner {
fn add_response_waiter(&self) -> Result<TxLabel> {
let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
let id = TxLabel::try_from(key as u8);
if id.is_err() {
warn!("Transaction IDs are exhausted");
let _ = self.response_waiters.lock().remove(key);
}
id
}
fn remove_response_interest(&self, id: &TxLabel) {
let mut lock = self.response_waiters.lock();
let idx = usize::from(id);
if lock[idx].is_received() {
let _ = lock.remove(idx);
} else {
lock[idx] = ResponseWaiter::Discard;
}
}
fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
let is_closed = self.recv_all(cx)?;
let mut lock = self.incoming_requests.lock();
if let Some(request) = lock.queue.pop_front() {
Poll::Ready(Ok(request))
} else {
lock.listener = RequestListener::Some(cx.waker().clone());
if is_closed {
Poll::Ready(Err(Error::PeerDisconnected))
} else {
Poll::Pending
}
}
}
fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
let is_closed = self.recv_all(cx)?;
let mut waiters = self.response_waiters.lock();
let idx = usize::from(label);
if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
let buf = waiters.remove(idx).unwrap_received();
Poll::Ready(Ok(buf))
} else {
*waiters.get_mut(idx).expect("Polled unregistered waiter") =
ResponseWaiter::Waiting(cx.waker().clone());
if is_closed {
Poll::Ready(Err(Error::PeerDisconnected))
} else {
Poll::Pending
}
}
}
fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
loop {
let mut next_packet = Vec::new();
let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
trace!("Signaling peer closed");
return Ok(true);
}
Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
Poll::Pending => return Ok(false),
Poll::Ready(Ok(size)) => size,
};
if packet_size == 0 {
continue;
}
let header = match SignalingHeader::decode(next_packet.as_slice()) {
Err(Error::InvalidSignalId(label, id)) => {
self.send_general_reject(label, id)?;
continue;
}
Err(_) => {
info!("received unrejectable message");
continue;
}
Ok(x) => x,
};
if header.is_command() {
let mut lock = self.incoming_requests.lock();
let body = next_packet.split_off(header.encoded_len());
lock.queue.push_back(UnparsedRequest::new(header, body));
if let RequestListener::Some(ref waker) = lock.listener {
waker.wake_by_ref();
}
} else {
let mut waiters = self.response_waiters.lock();
let idx = usize::from(&header.label());
if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
let _ = waiters.remove(idx);
} else if let Some(entry) = waiters.get_mut(idx) {
let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
if let ResponseWaiter::Waiting(waker) = old_entry {
waker.wake();
}
} else {
warn!("response for {:?} we did not send, dropping", header.label());
}
}
}
}
fn wake_any(&self) {
{
let lock = self.response_waiters.lock();
for (_, response_waiter) in lock.iter() {
if let ResponseWaiter::Waiting(waker) = response_waiter {
waker.wake_by_ref();
return;
}
}
}
{
let lock = self.incoming_requests.lock();
if let RequestListener::Some(waker) = &lock.listener {
waker.wake_by_ref();
return;
}
}
}
fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
self.send_signal(packet)
}
fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
header.encode(packet.as_mut_slice())?;
packet[header.encoded_len()..].clone_from_slice(params);
self.send_signal(&packet)
}
fn send_reject(
&self,
label: TxLabel,
signal: SignalIdentifier,
error_code: ErrorCode,
) -> Result<()> {
self.send_reject_params(label, signal, &[u8::from(&error_code)])
}
fn send_reject_params(
&self,
label: TxLabel,
signal: SignalIdentifier,
params: &[u8],
) -> Result<()> {
let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
header.encode(packet.as_mut_slice())?;
packet[header.encoded_len()..].clone_from_slice(params);
self.send_signal(&packet)
}
fn send_signal(&self, data: &[u8]) -> Result<()> {
let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
Ok(())
}
}