bt_avdtp/
lib.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::Stream;
10use futures::task::{Context, Poll, Waker};
11use futures::{ready, Future, FutureExt, TryFutureExt};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use zx::{self as zx, MonotonicDuration};
21
22#[cfg(test)]
23mod tests;
24
25mod rtp;
26mod stream_endpoint;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33    MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36    ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37    Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
40/// An AVDTP signaling peer can send commands to another peer, receive requests and send responses.
41/// Media transport is not handled by this peer.
42///
43/// Requests from the distant peer are delivered through the request stream available through
44/// take_request_stream().  Only one RequestStream can be active at a time.  Only valid requests
45/// are sent to the request stream - invalid formats are automatically rejected.
46///
47/// Responses are sent using responders that are included in the request stream from the connected
48/// peer.
49#[derive(Debug, Clone)]
50pub struct Peer {
51    inner: Arc<PeerInner>,
52}
53
54impl Peer {
55    /// Create a new peer from a signaling channel.
56    pub fn new(signaling: Channel) -> Self {
57        Self {
58            inner: Arc::new(PeerInner {
59                signaling,
60                response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
61                incoming_requests: Mutex::<RequestQueue>::default(),
62            }),
63        }
64    }
65
66    /// Take the event listener for this peer. Panics if the stream is already
67    /// held.
68    #[track_caller]
69    pub fn take_request_stream(&self) -> RequestStream {
70        {
71            let mut lock = self.inner.incoming_requests.lock();
72            if let RequestListener::None = lock.listener {
73                lock.listener = RequestListener::New;
74            } else {
75                panic!("Request stream has already been taken");
76            }
77        }
78
79        RequestStream { inner: self.inner.clone() }
80    }
81
82    /// Send a Stream End Point Discovery (Sec 8.6) command to the remote peer.
83    /// Asynchronously returns a the reply in a vector of endpoint information.
84    /// Error will be RemoteRejected if the remote peer rejected the command.
85    pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> {
86        self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
87    }
88
89    /// Send a Get Capabilities (Sec 8.7) command to the remote peer for the
90    /// given `stream_id`.
91    /// Asynchronously returns the reply which contains the ServiceCapabilities
92    /// reported.
93    /// In general, Get All Capabilities should be preferred to this command if is supported.
94    /// Error will be RemoteRejected if the remote peer rejects the command.
95    pub fn get_capabilities(
96        &self,
97        stream_id: &StreamEndpointId,
98    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
99        let stream_params = &[stream_id.to_msg()];
100        self.send_command::<GetCapabilitiesResponse>(
101            SignalIdentifier::GetCapabilities,
102            stream_params,
103        )
104        .ok_into()
105    }
106
107    /// Send a Get All Capabilities (Sec 8.8) command to the remote peer for the
108    /// given `stream_id`.
109    /// Asynchronously returns the reply which contains the ServiceCapabilities
110    /// reported.
111    /// Error will be RemoteRejected if the remote peer rejects the command.
112    pub fn get_all_capabilities(
113        &self,
114        stream_id: &StreamEndpointId,
115    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
116        let stream_params = &[stream_id.to_msg()];
117        self.send_command::<GetCapabilitiesResponse>(
118            SignalIdentifier::GetAllCapabilities,
119            stream_params,
120        )
121        .ok_into()
122    }
123
124    /// Send a Stream Configuration (Sec 8.9) command to the remote peer for the
125    /// given remote `stream_id`, communicating the association to a local
126    /// `local_stream_id` and the required stream `capabilities`.
127    /// Panics if `capabilities` is empty.
128    /// Error will be RemoteRejected if the remote refused.
129    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
130    pub fn set_configuration(
131        &self,
132        stream_id: &StreamEndpointId,
133        local_stream_id: &StreamEndpointId,
134        capabilities: &[ServiceCapability],
135    ) -> impl Future<Output = Result<()>> {
136        assert!(!capabilities.is_empty(), "must set at least one capability");
137        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
138        params[0] = stream_id.to_msg();
139        params[1] = local_stream_id.to_msg();
140        let mut idx = 2;
141        for capability in capabilities {
142            if let Err(e) = capability.encode(&mut params[idx..]) {
143                return futures::future::err(e).left_future();
144            }
145            idx += capability.encoded_len();
146        }
147        self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, &params)
148            .ok_into()
149            .right_future()
150    }
151
152    /// Send a Get Stream Configuration (Sec 8.10) command to the remote peer
153    /// for the given remote `stream_id`.
154    /// Asynchronously returns the set of ServiceCapabilities previously
155    /// configured between these two peers.
156    /// Error will be RemoteRejected if the remote peer rejects this command.
157    pub fn get_configuration(
158        &self,
159        stream_id: &StreamEndpointId,
160    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
161        let stream_params = &[stream_id.to_msg()];
162        self.send_command::<GetCapabilitiesResponse>(
163            SignalIdentifier::GetConfiguration,
164            stream_params,
165        )
166        .ok_into()
167    }
168
169    /// Send a Stream Reconfigure (Sec 8.11) command to the remote peer for the
170    /// given remote `stream_id`, to reconfigure the Application Service
171    /// capabilities in `capabilities`.
172    /// Note: Per the spec, only the Media Codec and Content Protection
173    /// capabilities will be accepted in this command.
174    /// Panics if there are no capabilities to configure.
175    /// Error will be RemoteRejected if the remote refused.
176    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
177    pub fn reconfigure(
178        &self,
179        stream_id: &StreamEndpointId,
180        capabilities: &[ServiceCapability],
181    ) -> impl Future<Output = Result<()>> {
182        assert!(!capabilities.is_empty(), "must set at least one capability");
183        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
184        params[0] = stream_id.to_msg();
185        let mut idx = 1;
186        for capability in capabilities {
187            if !capability.is_application() {
188                return futures::future::err(Error::Encoding).left_future();
189            }
190            if let Err(e) = capability.encode(&mut params[idx..]) {
191                return futures::future::err(e).left_future();
192            }
193            idx += capability.encoded_len();
194        }
195        self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, &params)
196            .ok_into()
197            .right_future()
198    }
199
200    /// Send a Open Stream Command (Sec 8.12) to the remote peer for the given
201    /// `stream_id`.
202    /// Error will be RemoteRejected if the remote peer rejects the command.
203    pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
204        let stream_params = &[stream_id.to_msg()];
205        self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
206    }
207
208    /// Send a Start Stream Command (Sec 8.13) to the remote peer for all the streams in
209    /// `stream_ids`.
210    /// Returns Ok(()) if the command is accepted, and RemoteStreamRejected with the stream
211    /// endpoint id and error code reported by the remote if the remote signals a failure.
212    pub fn start(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
213        let mut stream_params = Vec::with_capacity(stream_ids.len());
214        for stream_id in stream_ids {
215            stream_params.push(stream_id.to_msg());
216        }
217        self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
218    }
219
220    /// Send a Close Stream Command (Sec 8.14) to the remote peer for the given `stream_id`.
221    /// Error will be RemoteRejected if the remote peer rejects the command.
222    pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
223        let stream_params = &[stream_id.to_msg()];
224        let response: CommandResponseFut<SimpleResponse> =
225            self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
226        response.ok_into()
227    }
228
229    /// Send a Suspend Command (Sec 8.15) to the remote peer for all the streams in `stream_ids`.
230    /// Error will be RemoteRejected if the remote refused, with the stream endpoint identifier
231    /// indicated by the remote set in the RemoteReject.
232    pub fn suspend(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
233        let mut stream_params = Vec::with_capacity(stream_ids.len());
234        for stream_id in stream_ids {
235            stream_params.push(stream_id.to_msg());
236        }
237        let response: CommandResponseFut<SimpleResponse> =
238            self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
239        response.ok_into()
240    }
241
242    /// Send an Abort (Sec 8.16) to the remote peer for the given `stream_id`.
243    /// Returns Ok(()) if the command is accepted, and Err(Timeout) if the remote
244    /// timed out.  The remote peer is not allowed to reject this command, and
245    /// commands that have invalid `stream_id` will timeout instead.
246    pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
247        let stream_params = &[stream_id.to_msg()];
248        self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
249    }
250
251    /// Send a Delay Report (Sec 8.19) to the remote peer for the given `stream_id`.
252    /// `delay` is in tenths of milliseconds.
253    /// Error will be RemoteRejected if the remote peer rejects the command.
254    pub fn delay_report(
255        &self,
256        stream_id: &StreamEndpointId,
257        delay: u16,
258    ) -> impl Future<Output = Result<()>> {
259        let delay_bytes: [u8; 2] = delay.to_be_bytes();
260        let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
261        self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
262    }
263
264    /// The maximum amount of time we will wait for a response to a signaling command.
265    const RTX_SIG_TIMER_MS: i64 = 3000;
266    const COMMAND_TIMEOUT: MonotonicDuration =
267        MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
268
269    /// Sends a signal on the channel and receive a future that will complete
270    /// when we get the expected response.
271    fn send_command<D: Decodable<Error = Error>>(
272        &self,
273        signal: SignalIdentifier,
274        payload: &[u8],
275    ) -> CommandResponseFut<D> {
276        let send_result = (|| {
277            let id = self.inner.add_response_waiter()?;
278            let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
279            let mut buf = vec![0; header.encoded_len()];
280            header.encode(buf.as_mut_slice())?;
281            buf.extend_from_slice(payload);
282            self.inner.send_signal(buf.as_slice())?;
283            Ok(header)
284        })();
285
286        CommandResponseFut::new(send_result, self.inner.clone())
287    }
288}
289
290/// A future representing the result of a AVDTP command. Decodes the response when it arrives.
291struct CommandResponseFut<D: Decodable> {
292    id: SignalIdentifier,
293    fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
294    _phantom: PhantomData<D>,
295}
296
297impl<D: Decodable> Unpin for CommandResponseFut<D> {}
298
299impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
300    fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
301        let header = match send_result {
302            Err(e) => {
303                return Self {
304                    id: SignalIdentifier::Abort,
305                    fut: Box::pin(MaybeDone::Done(Err(e))),
306                    _phantom: PhantomData,
307                }
308            }
309            Ok(header) => header,
310        };
311        let response = CommandResponse { id: header.label(), inner: Some(inner) };
312        let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
313        let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
314
315        Self {
316            id: header.signal(),
317            fut: Box::pin(futures::future::maybe_done(timedout_fut)),
318            _phantom: PhantomData,
319        }
320    }
321}
322
323impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
324    type Output = Result<D>;
325
326    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
327        ready!(self.fut.poll_unpin(cx));
328        Poll::Ready(
329            self.fut
330                .as_mut()
331                .take_output()
332                .unwrap_or(Err(Error::AlreadyReceived))
333                .and_then(|buf| decode_signaling_response(self.id, buf)),
334        )
335    }
336}
337
338/// A request from the connected peer.
339/// Each variant of this includes a responder which implements two functions:
340///  - send(...) will send a response with the information provided.
341///  - reject(ErrorCode) will send an reject response with the given error code.
342#[derive(Debug)]
343pub enum Request {
344    Discover {
345        responder: DiscoverResponder,
346    },
347    GetCapabilities {
348        stream_id: StreamEndpointId,
349        responder: GetCapabilitiesResponder,
350    },
351    GetAllCapabilities {
352        stream_id: StreamEndpointId,
353        responder: GetCapabilitiesResponder,
354    },
355    SetConfiguration {
356        local_stream_id: StreamEndpointId,
357        remote_stream_id: StreamEndpointId,
358        capabilities: Vec<ServiceCapability>,
359        responder: ConfigureResponder,
360    },
361    GetConfiguration {
362        stream_id: StreamEndpointId,
363        responder: GetCapabilitiesResponder,
364    },
365    Reconfigure {
366        local_stream_id: StreamEndpointId,
367        capabilities: Vec<ServiceCapability>,
368        responder: ConfigureResponder,
369    },
370    Open {
371        stream_id: StreamEndpointId,
372        responder: SimpleResponder,
373    },
374    Start {
375        stream_ids: Vec<StreamEndpointId>,
376        responder: StreamResponder,
377    },
378    Close {
379        stream_id: StreamEndpointId,
380        responder: SimpleResponder,
381    },
382    Suspend {
383        stream_ids: Vec<StreamEndpointId>,
384        responder: StreamResponder,
385    },
386    Abort {
387        stream_id: StreamEndpointId,
388        responder: SimpleResponder,
389    },
390    DelayReport {
391        stream_id: StreamEndpointId,
392        delay: u16,
393        responder: SimpleResponder,
394    }, // TODO(jamuraa): add the rest of the requests
395}
396
/// Builds `Request::$request_variant` for a request whose body must be exactly one SEID.
///
/// Any other body length yields `Error::RequestInvalid(ErrorCode::BadLength)`.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        match $body {
            [seid] => Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(seid),
                responder: $responder_type { signal: $signal, peer: $peer, id: $id },
            }),
            _ => Err(Error::RequestInvalid(ErrorCode::BadLength)),
        }
    };
}
409
410impl Request {
411    fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
412        if body.len() < 1 {
413            return Err(Error::RequestInvalid(ErrorCode::BadLength));
414        }
415        Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
416    }
417
418    fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
419        if encoded.len() < 2 {
420            return Err(Error::RequestInvalid(ErrorCode::BadLength));
421        }
422        let mut caps = vec![];
423        let mut loc = 0;
424        while loc < encoded.len() {
425            let cap = match ServiceCapability::decode(&encoded[loc..]) {
426                Ok(cap) => cap,
427                Err(Error::RequestInvalid(code)) => {
428                    return Err(Error::RequestInvalidExtra(code, encoded[loc]));
429                }
430                Err(e) => return Err(e),
431            };
432            loc += cap.encoded_len();
433            caps.push(cap);
434        }
435        Ok(caps)
436    }
437
438    fn parse(
439        peer: Arc<PeerInner>,
440        id: TxLabel,
441        signal: SignalIdentifier,
442        body: &[u8],
443    ) -> Result<Request> {
444        match signal {
445            SignalIdentifier::Discover => {
446                // Discover Request has no body (Sec 8.6.1)
447                if body.len() > 0 {
448                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
449                }
450                Ok(Request::Discover { responder: DiscoverResponder { peer, id } })
451            }
452            SignalIdentifier::GetCapabilities => {
453                parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
454            }
455            SignalIdentifier::GetAllCapabilities => parse_one_seid!(
456                body,
457                signal,
458                peer,
459                id,
460                GetAllCapabilities,
461                GetCapabilitiesResponder
462            ),
463            SignalIdentifier::SetConfiguration => {
464                if body.len() < 4 {
465                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
466                }
467                let requested = Request::get_req_capabilities(&body[2..])?;
468                Ok(Request::SetConfiguration {
469                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
470                    remote_stream_id: StreamEndpointId::from_msg(&body[1]),
471                    capabilities: requested,
472                    responder: ConfigureResponder { signal, peer, id },
473                })
474            }
475            SignalIdentifier::GetConfiguration => {
476                parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
477            }
478            SignalIdentifier::Reconfigure => {
479                if body.len() < 3 {
480                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
481                }
482                let requested = Request::get_req_capabilities(&body[1..])?;
483                match requested.iter().find(|x| !x.is_application()) {
484                    Some(x) => {
485                        return Err(Error::RequestInvalidExtra(
486                            ErrorCode::InvalidCapabilities,
487                            (&x.category()).into(),
488                        ));
489                    }
490                    None => (),
491                };
492                Ok(Request::Reconfigure {
493                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
494                    capabilities: requested,
495                    responder: ConfigureResponder { signal, peer, id },
496                })
497            }
498            SignalIdentifier::Open => {
499                parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
500            }
501            SignalIdentifier::Start => {
502                let seids = Request::get_req_seids(body)?;
503                Ok(Request::Start {
504                    stream_ids: seids,
505                    responder: StreamResponder { signal, peer, id },
506                })
507            }
508            SignalIdentifier::Close => {
509                parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
510            }
511            SignalIdentifier::Suspend => {
512                let seids = Request::get_req_seids(body)?;
513                Ok(Request::Suspend {
514                    stream_ids: seids,
515                    responder: StreamResponder { signal, peer, id },
516                })
517            }
518            SignalIdentifier::Abort => {
519                parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
520            }
521            SignalIdentifier::DelayReport => {
522                if body.len() != 3 {
523                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
524                }
525                let delay_arr: [u8; 2] = [body[1], body[2]];
526                let delay = u16::from_be_bytes(delay_arr);
527                Ok(Request::DelayReport {
528                    stream_id: StreamEndpointId::from_msg(&body[0]),
529                    delay,
530                    responder: SimpleResponder { signal, peer, id },
531                })
532            }
533            _ => Err(Error::UnimplementedMessage),
534        }
535    }
536}
537
538/// A stream of requests from the remote peer.
539#[derive(Debug)]
540pub struct RequestStream {
541    inner: Arc<PeerInner>,
542}
543
544impl Unpin for RequestStream {}
545
546impl Stream for RequestStream {
547    type Item = Result<Request>;
548
549    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
550        Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
551            Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
552                match Request::parse(self.inner.clone(), label, signal, &body) {
553                    Err(Error::RequestInvalid(code)) => {
554                        self.inner.send_reject(label, signal, code)?;
555                        return Poll::Pending;
556                    }
557                    Err(Error::RequestInvalidExtra(code, extra)) => {
558                        self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
559                        return Poll::Pending;
560                    }
561                    Err(Error::UnimplementedMessage) => {
562                        self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
563                        return Poll::Pending;
564                    }
565                    x => Some(x),
566                }
567            }
568            Err(Error::PeerDisconnected) => None,
569            Err(e) => Some(Err(e)),
570        })
571    }
572}
573
574impl Drop for RequestStream {
575    fn drop(&mut self) {
576        self.inner.incoming_requests.lock().listener = RequestListener::None;
577        self.inner.wake_any();
578    }
579}
580
581// Simple responses have no body data.
582#[derive(Debug)]
583pub struct SimpleResponse {}
584
585impl Decodable for SimpleResponse {
586    type Error = Error;
587
588    fn decode(from: &[u8]) -> Result<Self> {
589        if from.len() > 0 {
590            return Err(Error::InvalidMessage);
591        }
592        Ok(SimpleResponse {})
593    }
594}
595
596impl Into<()> for SimpleResponse {
597    fn into(self) -> () {
598        ()
599    }
600}
601
602#[derive(Debug)]
603struct DiscoverResponse {
604    endpoints: Vec<StreamInformation>,
605}
606
607impl Decodable for DiscoverResponse {
608    type Error = Error;
609
610    fn decode(from: &[u8]) -> Result<Self> {
611        let mut endpoints = Vec::<StreamInformation>::new();
612        let mut idx = 0;
613        while idx < from.len() {
614            let endpoint = StreamInformation::decode(&from[idx..])?;
615            idx += endpoint.encoded_len();
616            endpoints.push(endpoint);
617        }
618        Ok(DiscoverResponse { endpoints })
619    }
620}
621
622impl Into<Vec<StreamInformation>> for DiscoverResponse {
623    fn into(self) -> Vec<StreamInformation> {
624        self.endpoints
625    }
626}
627
628#[derive(Debug)]
629pub struct DiscoverResponder {
630    peer: Arc<PeerInner>,
631    id: TxLabel,
632}
633
634impl DiscoverResponder {
635    /// Sends the response to a discovery request.
636    /// At least one endpoint must be present.
637    /// Will result in a Error::PeerWrite if the distant peer is disconnected.
638    pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
639        if endpoints.len() == 0 {
640            // There shall be at least one SEP in a response (Sec 8.6.2)
641            return Err(Error::Encoding);
642        }
643        let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
644        let mut idx = 0;
645        for endpoint in endpoints {
646            endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
647            idx += endpoint.encoded_len();
648        }
649        self.peer.send_response(self.id, SignalIdentifier::Discover, &params)
650    }
651
652    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
653        self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
654    }
655}
656
657#[derive(Debug)]
658pub struct GetCapabilitiesResponder {
659    peer: Arc<PeerInner>,
660    signal: SignalIdentifier,
661    id: TxLabel,
662}
663
664impl GetCapabilitiesResponder {
665    pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
666        let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
667        let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
668        let mut reply = vec![0 as u8; reply_len];
669        let mut pos = 0;
670        for capability in included_iter {
671            let size = capability.encoded_len();
672            capability.encode(&mut reply[pos..pos + size])?;
673            pos += size;
674        }
675        self.peer.send_response(self.id, self.signal, &reply)
676    }
677
678    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
679        self.peer.send_reject(self.id, self.signal, error_code)
680    }
681}
682
683#[derive(Debug)]
684struct GetCapabilitiesResponse {
685    capabilities: Vec<ServiceCapability>,
686}
687
688impl Decodable for GetCapabilitiesResponse {
689    type Error = Error;
690
691    fn decode(from: &[u8]) -> Result<Self> {
692        let mut capabilities = Vec::<ServiceCapability>::new();
693        let mut idx = 0;
694        while idx < from.len() {
695            match ServiceCapability::decode(&from[idx..]) {
696                Ok(capability) => {
697                    idx = idx + capability.encoded_len();
698                    capabilities.push(capability);
699                }
700                Err(_) => {
701                    // The capability length of the invalid capability can be nonzero.
702                    // Advance `idx` by the payload amount, but don't push the invalid capability.
703                    // Increment by 1 byte for ServiceCategory, 1 byte for payload length,
704                    // `length_of_capability` bytes for capability length.
705                    info!(
706                        "GetCapabilitiesResponse decode: Capability {:?} not supported.",
707                        from[idx]
708                    );
709                    let length_of_capability = from[idx + 1] as usize;
710                    idx = idx + 2 + length_of_capability;
711                }
712            }
713        }
714        Ok(GetCapabilitiesResponse { capabilities })
715    }
716}
717
718impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
719    fn into(self) -> Vec<ServiceCapability> {
720        self.capabilities
721    }
722}
723
724#[derive(Debug)]
725pub struct SimpleResponder {
726    peer: Arc<PeerInner>,
727    signal: SignalIdentifier,
728    id: TxLabel,
729}
730
731impl SimpleResponder {
732    pub fn send(self) -> Result<()> {
733        self.peer.send_response(self.id, self.signal, &[])
734    }
735
736    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
737        self.peer.send_reject(self.id, self.signal, error_code)
738    }
739}
740
741#[derive(Debug)]
742pub struct StreamResponder {
743    peer: Arc<PeerInner>,
744    signal: SignalIdentifier,
745    id: TxLabel,
746}
747
748impl StreamResponder {
749    pub fn send(self) -> Result<()> {
750        self.peer.send_response(self.id, self.signal, &[])
751    }
752
753    pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
754        self.peer.send_reject_params(
755            self.id,
756            self.signal,
757            &[stream_id.to_msg(), u8::from(&error_code)],
758        )
759    }
760}
761
762#[derive(Debug)]
763pub struct ConfigureResponder {
764    peer: Arc<PeerInner>,
765    signal: SignalIdentifier,
766    id: TxLabel,
767}
768
769impl ConfigureResponder {
770    pub fn send(self) -> Result<()> {
771        self.peer.send_response(self.id, self.signal, &[])
772    }
773
774    pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
775        self.peer.send_reject_params(
776            self.id,
777            self.signal,
778            &[u8::from(&category), u8::from(&error_code)],
779        )
780    }
781}
782
783#[derive(Debug)]
784struct UnparsedRequest(SignalingHeader, Vec<u8>);
785
786impl UnparsedRequest {
787    fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
788        UnparsedRequest(header, body)
789    }
790}
791
792#[derive(Debug, Default)]
793struct RequestQueue {
794    listener: RequestListener,
795    queue: VecDeque<UnparsedRequest>,
796}
797
/// The state of the (at most one) consumer of incoming requests.
///
/// Defaults to `None`: nobody is listening.
#[derive(Debug, Default)]
enum RequestListener {
    /// No one is listening.
    #[default]
    None,
    /// Someone wants to listen but hasn't polled yet.
    New,
    /// Someone is listening, and can be woken with this waker.
    Some(Waker),
}
813
/// The state of our interest in the response to one outstanding command.
#[derive(Debug)]
enum ResponseWaiter {
    /// Registered, but no task has polled for the response yet.
    WillPoll,
    /// A task is waiting for the response and can be woken with this waker.
    Waiting(Waker),
    /// The response has arrived and is held here until it is polled, at which point it will
    /// be decoded.
    Received(Vec<u8>),
    /// The response is still outstanding, but the requester no longer cares; it will be
    /// thrown out.
    Discard,
}

impl ResponseWaiter {
    /// Returns whether the response has arrived and is stored in this waiter.
    fn is_received(&self) -> bool {
        matches!(self, ResponseWaiter::Received(_))
    }

    /// Consumes the waiter, returning the received response bytes.
    ///
    /// # Panics
    ///
    /// Panics if the waiter is not in the `Received` state.
    fn unwrap_received(self) -> Vec<u8> {
        match self {
            ResponseWaiter::Received(buf) => buf,
            _ => panic!("expected received buf"),
        }
    }
}
847
848fn decode_signaling_response<D: Decodable<Error = Error>>(
849    expected_signal: SignalIdentifier,
850    buf: Vec<u8>,
851) -> Result<D> {
852    let header = SignalingHeader::decode(buf.as_slice())?;
853    if header.signal() != expected_signal {
854        return Err(Error::InvalidHeader);
855    }
856    let params = &buf[header.encoded_len()..];
857    match header.message_type {
858        SignalingMessageType::ResponseAccept => D::decode(params),
859        SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
860            Err(RemoteReject::from_params(header.signal(), params).into())
861        }
862        SignalingMessageType::Command => unreachable!(),
863    }
864}
865
/// A future that polls for the response to a command we sent.
///
/// It resolves to the raw response packet, header included. Dropping it before it
/// resolves successfully tells the peer to discard the response for `id`.
#[derive(Debug)]
pub struct CommandResponse {
    /// Transaction label the response is expected on.
    id: TxLabel,
    // Some(x) if we're still waiting on the response; taken (set to None) once the
    // response has been received.
    inner: Option<Arc<PeerInner>>,
}
873
// Explicitly Unpin so `poll` can take `&mut *self` from `Pin<&mut Self>` without pin
// projection. The struct holds only a label and an optional `Arc`, with no self-references.
impl Unpin for CommandResponse {}
875
876impl Future for CommandResponse {
877    type Output = Result<Vec<u8>>;
878    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
879        let this = &mut *self;
880        let res;
881        {
882            let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
883            res = client.poll_recv_response(&this.id, cx);
884        }
885
886        if let Poll::Ready(Ok(_)) = res {
887            let inner = this.inner.take().expect("CommandResponse polled after completion");
888            inner.wake_any();
889        }
890
891        res
892    }
893}
894
895impl FusedFuture for CommandResponse {
896    fn is_terminated(&self) -> bool {
897        self.inner.is_none()
898    }
899}
900
901impl Drop for CommandResponse {
902    fn drop(&mut self) {
903        if let Some(inner) = &self.inner {
904            inner.remove_response_interest(&self.id);
905            inner.wake_any();
906        }
907    }
908}
909
/// Shared state behind a `Peer`: the signaling channel, outstanding command transactions,
/// and commands received from the remote peer.
#[derive(Debug)]
struct PeerInner {
    /// The signaling channel
    signaling: Channel,

    /// A map of transaction ids that have been sent but the response has not
    /// been received and/or processed yet. The slab key is the transaction label.
    ///
    /// Waiters are added with `add_response_waiter`. They are removed in one of three ways:
    /// by `poll_recv_response` once their response has been received; by
    /// `remove_response_interest` if the response had already arrived when interest was
    /// dropped; or by `recv_all` when a response arrives for a `Discard`ed waiter.
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// A queue of requests that have been received and are waiting to
    /// be handed out by `poll_recv_request`, along with the waker for the task that has
    /// taken the request receiver (if it exists)
    incoming_requests: Mutex<RequestQueue>,
}
927
928impl PeerInner {
929    /// Add a response waiter, and return a id that can be used to send the
930    /// transaction.  Responses then can be received using poll_recv_response
931    fn add_response_waiter(&self) -> Result<TxLabel> {
932        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
933        let id = TxLabel::try_from(key as u8);
934        if id.is_err() {
935            warn!("Transaction IDs are exhausted");
936            let _ = self.response_waiters.lock().remove(key);
937        }
938        id
939    }
940
941    /// When a waiter isn't interested in the response anymore, we need to just
942    /// throw it out.  This is called when the response future is dropped.
943    fn remove_response_interest(&self, id: &TxLabel) {
944        let mut lock = self.response_waiters.lock();
945        let idx = usize::from(id);
946        if lock[idx].is_received() {
947            let _ = lock.remove(idx);
948        } else {
949            lock[idx] = ResponseWaiter::Discard;
950        }
951    }
952
953    // Attempts to receive a new request by processing all packets on the socket.
954    // Resolves to an unprocessed request (header, body) if one was received.
955    // Resolves to an error if there was an error reading from the socket or if the peer
956    // disconnected.
957    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
958        let is_closed = self.recv_all(cx)?;
959
960        let mut lock = self.incoming_requests.lock();
961
962        if let Some(request) = lock.queue.pop_front() {
963            Poll::Ready(Ok(request))
964        } else {
965            lock.listener = RequestListener::Some(cx.waker().clone());
966            if is_closed {
967                Poll::Ready(Err(Error::PeerDisconnected))
968            } else {
969                Poll::Pending
970            }
971        }
972    }
973
974    // Attempts to receive a response to a request by processing all packets on the socket.
975    // Resolves to the bytes in the response body if one was received.
976    // Resolves to an error if there was an error reading from the socket, if the peer
977    // disconnected, or if the |label| is not being waited on.
978    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
979        let is_closed = self.recv_all(cx)?;
980
981        let mut waiters = self.response_waiters.lock();
982        let idx = usize::from(label);
983        // We expect() below because the label above came from an internally-created object,
984        // so the waiters should always exist in the map.
985        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
986            // We got our response.
987            let buf = waiters.remove(idx).unwrap_received();
988            Poll::Ready(Ok(buf))
989        } else {
990            // Set the waker to be notified when a response shows up.
991            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
992                ResponseWaiter::Waiting(cx.waker().clone());
993
994            if is_closed {
995                Poll::Ready(Err(Error::PeerDisconnected))
996            } else {
997                Poll::Pending
998            }
999        }
1000    }
1001
1002    /// Poll for any packets on the signaling socket
1003    /// Returns whether the channel was closed, or an Error::PeerRead or Error::PeerWrite
1004    /// if there was a problem communicating on the socket.
1005    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
1006        loop {
1007            let mut next_packet = Vec::new();
1008            let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
1009                Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
1010                    trace!("Signaling peer closed");
1011                    return Ok(true);
1012                }
1013                Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
1014                Poll::Pending => return Ok(false),
1015                Poll::Ready(Ok(size)) => size,
1016            };
1017            if packet_size == 0 {
1018                continue;
1019            }
1020            // Detects General Reject condition and sends the response back.
1021            // On other headers with errors, sends BAD_HEADER to the peer
1022            // and attempts to continue.
1023            let header = match SignalingHeader::decode(next_packet.as_slice()) {
1024                Err(Error::InvalidSignalId(label, id)) => {
1025                    self.send_general_reject(label, id)?;
1026                    continue;
1027                }
1028                Err(_) => {
1029                    // Only possible other return is OutOfRange
1030                    // Returned only when the packet is too small, can't make a meaningful reject.
1031                    info!("received unrejectable message");
1032                    continue;
1033                }
1034                Ok(x) => x,
1035            };
1036            // Commands from the remote get translated into requests.
1037            if header.is_command() {
1038                let mut lock = self.incoming_requests.lock();
1039                let body = next_packet.split_off(header.encoded_len());
1040                lock.queue.push_back(UnparsedRequest::new(header, body));
1041                if let RequestListener::Some(ref waker) = lock.listener {
1042                    waker.wake_by_ref();
1043                }
1044            } else {
1045                // Should be a response to a command we sent
1046                let mut waiters = self.response_waiters.lock();
1047                let idx = usize::from(&header.label());
1048                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
1049                    let _ = waiters.remove(idx);
1050                } else if let Some(entry) = waiters.get_mut(idx) {
1051                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
1052                    if let ResponseWaiter::Waiting(waker) = old_entry {
1053                        waker.wake();
1054                    }
1055                } else {
1056                    warn!("response for {:?} we did not send, dropping", header.label());
1057                }
1058                // Note: we drop any TxLabel response we are not waiting for
1059            }
1060        }
1061    }
1062
1063    // Wakes up an arbitrary task that has begun polling on the channel so that
1064    // it will call recv_all and be registered as the new channel reader.
1065    fn wake_any(&self) {
1066        // Try to wake up response waiters first, rather than the event listener.
1067        // The event listener is a stream, and so could be between poll_nexts,
1068        // Response waiters should always be actively polled once
1069        // they've begun being polled on a task.
1070        {
1071            let lock = self.response_waiters.lock();
1072            for (_, response_waiter) in lock.iter() {
1073                if let ResponseWaiter::Waiting(waker) = response_waiter {
1074                    waker.wake_by_ref();
1075                    return;
1076                }
1077            }
1078        }
1079        {
1080            let lock = self.incoming_requests.lock();
1081            if let RequestListener::Some(waker) = &lock.listener {
1082                waker.wake_by_ref();
1083                return;
1084            }
1085        }
1086    }
1087
1088    // Build and send a General Reject message (Section 8.18)
1089    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
1090        // Build the packet ourselves rather than make SignalingHeader build an packet with an
1091        // invalid signal id.
1092        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
1093        self.send_signal(packet)
1094    }
1095
1096    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
1097        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
1098        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1099        header.encode(packet.as_mut_slice())?;
1100        packet[header.encoded_len()..].clone_from_slice(params);
1101        self.send_signal(&packet)
1102    }
1103
1104    fn send_reject(
1105        &self,
1106        label: TxLabel,
1107        signal: SignalIdentifier,
1108        error_code: ErrorCode,
1109    ) -> Result<()> {
1110        self.send_reject_params(label, signal, &[u8::from(&error_code)])
1111    }
1112
1113    fn send_reject_params(
1114        &self,
1115        label: TxLabel,
1116        signal: SignalIdentifier,
1117        params: &[u8],
1118    ) -> Result<()> {
1119        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
1120        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1121        header.encode(packet.as_mut_slice())?;
1122        packet[header.encoded_len()..].clone_from_slice(params);
1123        self.send_signal(&packet)
1124    }
1125
1126    fn send_signal(&self, data: &[u8]) -> Result<()> {
1127        let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
1128        Ok(())
1129    }
1130}