bt_broadcast_assistant/
assistant.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use thiserror::Error;
10
11use bt_bap::types::BroadcastId;
12use bt_bass::client::error::Error as BassClientError;
13use bt_bass::client::BroadcastAudioScanServiceClient;
14use bt_common::{PeerId, Uuid};
15use bt_gatt::central::*;
16use bt_gatt::client::PeerServiceHandle;
17use bt_gatt::types::Error as GattError;
18use bt_gatt::Client;
19
20pub mod event;
21use event::*;
22pub mod peer;
23pub use peer::Peer;
24
25use crate::types::*;
26
27pub const BROADCAST_AUDIO_SCAN_SERVICE: Uuid = Uuid::from_u16(0x184F);
28pub const BASIC_AUDIO_ANNOUNCEMENT_SERVICE: Uuid = Uuid::from_u16(0x1851);
29pub const BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE: Uuid = Uuid::from_u16(0x1852);
30
31#[derive(Debug, Error)]
32pub enum Error {
33    #[error("GATT operation error: {0:?}")]
34    Gatt(#[from] GattError),
35
36    #[error("Broadcast Audio Scan Service client error at peer ({0}): {1:?}")]
37    BassClient(PeerId, BassClientError),
38
39    #[error("Not connected to Broadcast Audio Scan Service at peer ({0})")]
40    NotConnectedToBass(PeerId),
41
42    #[error("Central scanning terminated unexpectedly")]
43    CentralScanTerminated,
44
45    #[error("Failed to connect to service ({1}) at peer ({0})")]
46    ConnectionFailure(PeerId, Uuid),
47
48    #[error("Broadcast Assistant was already started previously. It cannot be started twice")]
49    AlreadyStarted,
50
51    #[error("Failed due to error: {0}")]
52    Generic(String),
53}
54
55/// Contains information about the currently-known broadcast
56/// sources and the peers they were found on
57#[derive(Debug)]
58pub(crate) struct DiscoveredBroadcastSources(Mutex<HashMap<PeerId, BroadcastSource>>);
59
60impl DiscoveredBroadcastSources {
61    /// Creates a shareable instance of `DiscoveredBroadcastSources`.
62    pub fn new() -> Arc<Self> {
63        Arc::new(Self(Mutex::new(HashMap::new())))
64    }
65
66    /// Merges the broadcast source data with existing broadcast source data.
67    /// Returns the copy of the broadcast source data after the merge and
68    /// indicates whether it has changed from before or not.
69    pub(crate) fn merge_broadcast_source_data(
70        &self,
71        peer_id: &PeerId,
72        data: &BroadcastSource,
73    ) -> (BroadcastSource, bool) {
74        let mut lock = self.0.lock();
75        let source = lock.entry(*peer_id).or_default();
76        let before = source.clone();
77
78        source.merge(data);
79        let after = source.clone();
80        let changed = before != after;
81
82        (after, changed)
83    }
84
85    /// Get a BroadcastSource from a peer id.
86    fn get_by_peer_id(&self, peer_id: &PeerId) -> Option<BroadcastSource> {
87        let lock = self.0.lock();
88        lock.get(&peer_id).clone().map(|source| source.clone())
89    }
90
91    /// Get a BroadcastSource from associated broadcast id.
92    fn get_by_broadcast_id(&self, broadcast_id: &BroadcastId) -> Option<BroadcastSource> {
93        let lock = self.0.lock();
94        let info = lock.iter().find(|(&_k, &ref v)| v.broadcast_id == Some(*broadcast_id));
95        match info {
96            Some((&_peer_id, &ref broadcast_source)) => Some(broadcast_source.clone()),
97            None => None,
98        }
99    }
100}
101
102pub struct BroadcastAssistant<T: bt_gatt::GattTypes> {
103    central: T::Central,
104    broadcast_sources: Arc<DiscoveredBroadcastSources>,
105    broadcast_source_scan_started: Arc<AtomicBool>,
106}
107
108impl<T: bt_gatt::GattTypes + 'static> BroadcastAssistant<T> {
109    // Creates a broadcast assistant and sets it up to be ready
110    // for broadcast source scanning. Clients must use the `start`
111    // method to poll the event stream for scan results.
112    pub fn new(central: T::Central) -> Self {
113        Self {
114            central,
115            broadcast_sources: DiscoveredBroadcastSources::new(),
116            broadcast_source_scan_started: Arc::new(AtomicBool::new(false)),
117        }
118    }
119
120    /// List of scan filters for advertisement data Broadcast Assistant should
121    /// look for, which are:
122    /// - Service data with Broadcast Audio Announcement Service UUID from
123    ///   Broadcast Sources (see BAP spec v1.0.1 Section 3.7.2.1 for details)
124    // TODO(b/308481381): define filter for finding broadcast sink.
125    fn scan_filters() -> Vec<ScanFilter> {
126        vec![Filter::HasServiceData(BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE).into()]
127    }
128
129    /// Start broadcast assistant. Returns EventStream that the upper layer can
130    /// poll. Upper layer can call methods on BroadcastAssistant based on the
131    /// events it sees.
132    pub fn start(&mut self) -> Result<EventStream<T>, Error> {
133        if self.is_started() {
134            return Err(Error::AlreadyStarted);
135        }
136        let scan_result_stream = self.central.scan(&Self::scan_filters());
137        self.broadcast_source_scan_started.store(true, Ordering::Relaxed);
138        Ok(EventStream::<T>::new(
139            scan_result_stream,
140            self.broadcast_sources.clone(),
141            self.broadcast_source_scan_started.clone(),
142        ))
143    }
144
145    /// Returns whether or not Broadcast Assistant has started.
146    fn is_started(&self) -> bool {
147        self.broadcast_source_scan_started.load(Ordering::Relaxed)
148    }
149
150    pub fn scan_for_scan_delegators(&mut self) -> Result<T::ScanResultStream, Error> {
151        if self.is_started() {
152            return Err(Error::Generic(format!(
153                "Cannot scan for scan delegators while scanning for broadcast sources"
154            )));
155        }
156        // Scan for service data with Broadcast Audio Scan Service UUID to look
157        // for Broadcast Sink collocated with the Scan Delegator (see BAP spec v1.0.1
158        // Section 3.9.2 for details).
159        Ok(self.central.scan(&vec![Filter::HasServiceData(BROADCAST_AUDIO_SCAN_SERVICE).into()]))
160    }
161
162    pub async fn connect_to_scan_delegator(&self, peer_id: PeerId) -> Result<Peer<T>, Error>
163    where
164        <T as bt_gatt::GattTypes>::NotificationStream: std::marker::Send,
165    {
166        let client = self.central.connect(peer_id).await?;
167        let service_handles = client.find_service(BROADCAST_AUDIO_SCAN_SERVICE).await?;
168
169        for handle in service_handles {
170            if handle.uuid() != BROADCAST_AUDIO_SCAN_SERVICE || !handle.is_primary() {
171                continue;
172            }
173            let service = handle.connect().await?;
174            let bass = BroadcastAudioScanServiceClient::<T>::create(service)
175                .await
176                .map_err(|e| Error::BassClient(peer_id, e))?;
177
178            let connected_peer =
179                Peer::<T>::new(peer_id, client, bass, self.broadcast_sources.clone());
180            return Ok(connected_peer);
181        }
182        Err(Error::ConnectionFailure(peer_id, BROADCAST_AUDIO_SCAN_SERVICE))
183    }
184
185    // Manually adds broadcast source information for debugging purposes.
186    #[cfg(any(test, feature = "debug"))]
187    pub fn force_discover_broadcast_source(
188        &self,
189        peer_id: PeerId,
190        address: [u8; 6],
191        address_type: bt_common::core::AddressType,
192        advertising_sid: bt_common::core::AdvertisingSetId,
193    ) -> Result<BroadcastSource, Error> {
194        let broadcast_source = BroadcastSource {
195            address: Some(address),
196            address_type: Some(address_type),
197            advertising_sid: Some(advertising_sid),
198            broadcast_id: None,
199            pa_interval: None,
200            endpoint: None,
201        };
202
203        Ok(self.broadcast_sources.merge_broadcast_source_data(&peer_id, &broadcast_source).0)
204    }
205
206    // Manually adds broadcast source information for debugging purposes.
207    #[cfg(any(test, feature = "debug"))]
208    pub fn force_discover_broadcast_source_metadata(
209        &self,
210        peer_id: PeerId,
211        big_metadata: Vec<Vec<bt_common::generic_audio::metadata_ltv::Metadata>>,
212    ) -> Result<BroadcastSource, Error> {
213        use bt_bap::types::{BroadcastAudioSourceEndpoint, BroadcastIsochronousGroup};
214        use bt_common::core::CodecId;
215
216        let mut big = Vec::new();
217        for metadata in big_metadata {
218            let group = BroadcastIsochronousGroup {
219                codec_id: CodecId::Assigned(bt_common::core::CodingFormat::ALawLog), // mock.
220                codec_specific_configs: vec![],
221                metadata,
222                bis: vec![],
223            };
224            big.push(group);
225        }
226
227        let endpoint = BroadcastAudioSourceEndpoint { presentation_delay_ms: 0, big };
228
229        let broadcast_source = BroadcastSource {
230            address: None,
231            address_type: None,
232            advertising_sid: None,
233            broadcast_id: None,
234            pa_interval: None,
235            endpoint: Some(endpoint),
236        };
237
238        Ok(self.broadcast_sources.merge_broadcast_source_data(&peer_id, &broadcast_source).0)
239    }
240
241    // Gets the broadcast sources currently known by the broadcast
242    // assistant.
243    pub fn known_broadcast_sources(&self) -> std::collections::HashMap<PeerId, BroadcastSource> {
244        let lock = self.broadcast_sources.0.lock();
245        let mut m = HashMap::new();
246        for (pid, source) in lock.iter() {
247            m.insert(*pid, source.clone());
248        }
249        m
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    use futures::{pin_mut, FutureExt};
    use std::task::Poll;

    use bt_bap::types::*;
    use bt_common::core::{AddressType, AdvertisingSetId};
    use bt_common::generic_audio::metadata_ltv::Metadata;
    use bt_gatt::test_utils::{FakeCentral, FakeClient, FakeTypes};

    use crate::assistant::peer::tests::fake_bass_service;

    // Merging accumulates data per peer and reports whether the entry changed.
    #[test]
    fn merge_broadcast_source() {
        let discovered = DiscoveredBroadcastSources::new();
        let peer_id = PeerId(1001);
        let bid = BroadcastId::try_from(1001).unwrap();

        // The first merge for an unseen peer fills in the provided fields.
        let (merged, changed) = discovered.merge_broadcast_source_data(
            &peer_id,
            &BroadcastSource::default()
                .with_address([1, 2, 3, 4, 5, 6])
                .with_address_type(AddressType::Public)
                .with_advertising_sid(AdvertisingSetId(1))
                .with_broadcast_id(bid),
        );
        let expected = BroadcastSource {
            address: Some([1, 2, 3, 4, 5, 6]),
            address_type: Some(AddressType::Public),
            advertising_sid: Some(AdvertisingSetId(1)),
            broadcast_id: Some(bid),
            pa_interval: None,
            endpoint: None,
        };
        assert!(changed);
        assert_eq!(merged, expected);

        // A later merge overrides the address type and adds an endpoint; the
        // fields it leaves unset keep their earlier values.
        let (merged, changed) = discovered.merge_broadcast_source_data(
            &peer_id,
            &BroadcastSource::default().with_address_type(AddressType::Random).with_endpoint(
                BroadcastAudioSourceEndpoint { presentation_delay_ms: 32, big: vec![] },
            ),
        );
        let expected = BroadcastSource {
            address_type: Some(AddressType::Random),
            endpoint: Some(BroadcastAudioSourceEndpoint {
                presentation_delay_ms: 32,
                big: vec![],
            }),
            ..expected
        };
        assert!(changed);
        assert_eq!(merged, expected);

        // Merging identical data again must not report a change.
        let (_, changed) = discovered.merge_broadcast_source_data(
            &peer_id,
            &BroadcastSource::default().with_address_type(AddressType::Random).with_endpoint(
                BroadcastAudioSourceEndpoint { presentation_delay_ms: 32, big: vec![] },
            ),
        );
        assert!(!changed);
    }

    // Only one event stream may be alive at a time; dropping it allows a restart.
    #[test]
    fn start_stream() {
        let mut assistant = BroadcastAssistant::<FakeTypes>::new(FakeCentral::new());
        let first_stream = assistant.start().expect("can start stream");

        // A second start while the first stream is alive is rejected.
        assert!(assistant.is_started());
        assert!(matches!(assistant.start(), Err(_)));

        // Once the stream is gone, starting again succeeds.
        drop(first_stream);
        assert!(!assistant.is_started());
        assert!(assistant.start().is_ok());
    }

    #[test]
    fn connect_to_scan_delegator() {
        // Fake GATT setup: peer 1004 exposes a BASS service (the `true` flag
        // presumably marks it as primary).
        let mut central = FakeCentral::new();
        let mut client = FakeClient::new();
        central.add_client(PeerId(1004), client.clone());
        let service = fake_bass_service();
        client.add_service(BROADCAST_AUDIO_SCAN_SERVICE, true, service.clone());

        let assistant = BroadcastAssistant::<FakeTypes>::new(central);
        let conn_fut = assistant.connect_to_scan_delegator(PeerId(1004));
        pin_mut!(conn_fut);

        // The fakes are expected to resolve the connection in a single poll.
        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
        match conn_fut.poll_unpin(&mut noop_cx) {
            Poll::Ready(res) => {
                let _ = res.expect("should be ok");
            }
            Poll::Pending => panic!("should be ready"),
        }
    }

    #[test]
    fn force_discover_broadcast_source_test() {
        let assistant = BroadcastAssistant::<FakeTypes>::new(FakeCentral::new());
        let address = [1, 2, 3, 4, 5, 6];
        let address_type = AddressType::Public;
        let sid = AdvertisingSetId(1);

        let source = assistant
            .force_discover_broadcast_source(PeerId(1), address, address_type, sid)
            .unwrap();

        assert_eq!(source.address, Some(address));
        assert_eq!(source.address_type, Some(address_type));
        assert_eq!(source.advertising_sid, Some(sid));
    }

    #[test]
    fn force_discover_broadcast_source_metadata_test() {
        let assistant = BroadcastAssistant::<FakeTypes>::new(FakeCentral::new());
        let metadata = vec![vec![Metadata::BroadcastAudioImmediateRenderingFlag]];

        let source = assistant
            .force_discover_broadcast_source_metadata(PeerId(1), metadata.clone())
            .unwrap();

        // One BIG per metadata list, carrying that list unchanged.
        let endpoint = source.endpoint.unwrap();
        assert_eq!(endpoint.big.len(), 1);
        assert_eq!(endpoint.big[0].metadata, metadata[0]);
    }
}