1use core::pin::Pin;
6use futures::stream::{FusedStream, Stream, StreamExt};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::task::Poll;
10
11use bt_bap::types::{BroadcastAudioSourceEndpoint, BroadcastId};
12use bt_common::core::{AdvertisingSetId, PeriodicAdvertisingInterval};
13use bt_common::packet_encoding::Decodable;
14use bt_common::packet_encoding::Error as PacketError;
15use bt_common::PeerId;
16use bt_gatt::central::{AdvertisingDatum, ScanResult};
17
18use crate::assistant::{
19 DiscoveredBroadcastSources, Error, BASIC_AUDIO_ANNOUNCEMENT_SERVICE,
20 BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE,
21};
22use crate::types::BroadcastSource;
23
24#[derive(Debug)]
25pub enum Event {
26 FoundBroadcastSource { peer: PeerId, source: BroadcastSource },
27 CouldNotParseAdvertisingData { peer: PeerId, error: PacketError },
28}
29
30pub struct EventStream<T: bt_gatt::GattTypes> {
34 scan_result_stream: Pin<Box<<T as bt_gatt::GattTypes>::ScanResultStream>>,
35 terminated: bool,
36
37 broadcast_sources: Arc<DiscoveredBroadcastSources>,
38 broadcast_source_scan_started: Arc<AtomicBool>,
39}
40
41impl<T: bt_gatt::GattTypes> EventStream<T> {
42 pub(crate) fn new(
43 scan_result_stream: T::ScanResultStream,
44 broadcast_sources: Arc<DiscoveredBroadcastSources>,
45 broadcast_source_scan_started: Arc<AtomicBool>,
46 ) -> Self {
47 Self {
48 scan_result_stream: Box::pin(scan_result_stream),
49 terminated: false,
50 broadcast_sources,
51 broadcast_source_scan_started,
52 }
53 }
54
55 fn try_into_broadcast_source(
59 scan_result: &ScanResult,
60 ) -> Result<Option<BroadcastSource>, PacketError> {
61 let mut source = None;
62 for datum in &scan_result.advertised {
63 let AdvertisingDatum::ServiceData(uuid, data) = datum else {
64 continue;
65 };
66 if *uuid == BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE {
67 let bid = BroadcastId::decode(data.as_slice()).0?;
68 source.get_or_insert(BroadcastSource::default()).with_broadcast_id(bid);
69 } else if *uuid == BASIC_AUDIO_ANNOUNCEMENT_SERVICE {
70 let base = BroadcastAudioSourceEndpoint::decode(data.as_slice()).0?;
72 source.get_or_insert(BroadcastSource::default()).with_endpoint(base);
73 }
74 }
75 if let Some(src) = &mut source {
76 src.advertising_sid = scan_result.advertising_sid.map(AdvertisingSetId);
77 src.periodic_advertising_interval =
78 scan_result.periodic_advertising_interval.map(PeriodicAdvertisingInterval);
79 }
80 Ok(source)
81 }
82}
83
84impl<T: bt_gatt::GattTypes> Drop for EventStream<T> {
85 fn drop(&mut self) {
86 self.broadcast_source_scan_started.store(false, Ordering::Relaxed);
87 }
88}
89
90impl<T: bt_gatt::GattTypes> FusedStream for EventStream<T> {
91 fn is_terminated(&self) -> bool {
92 self.terminated
93 }
94}
95
96impl<T: bt_gatt::GattTypes> Stream for EventStream<T> {
97 type Item = Result<Event, Error>;
98
99 fn poll_next(
100 mut self: std::pin::Pin<&mut Self>,
101 cx: &mut std::task::Context<'_>,
102 ) -> Poll<Option<Self::Item>> {
103 if self.terminated {
104 return Poll::Ready(None);
105 }
106
107 loop {
110 let Some(Ok(scanned)) = futures::ready!(self.scan_result_stream.poll_next_unpin(cx))
111 else {
112 self.terminated = true;
113 self.broadcast_source_scan_started.store(false, Ordering::Relaxed);
114 return Poll::Ready(Some(Err(Error::CentralScanTerminated)));
115 };
116
117 let found_source = match Self::try_into_broadcast_source(&scanned) {
118 Err(error) => {
119 return Poll::Ready(Some(Ok(Event::CouldNotParseAdvertisingData {
120 peer: scanned.id,
121 error,
122 })));
123 }
124 Ok(None) => continue,
125 Ok(Some(source)) => source,
126 };
127
128 let (broadcast_source, changed) =
131 self.broadcast_sources.merge_broadcast_source_data(&scanned.id, &found_source);
132
133 if broadcast_source.into_add_source() && changed {
136 return Poll::Ready(Some(Ok(Event::FoundBroadcastSource {
137 peer: scanned.id,
138 source: broadcast_source,
139 })));
140 }
141 }
142 }
143}
144
145#[cfg(test)]
146mod tests {
147 use super::*;
148
149 use assert_matches::assert_matches;
150
151 use bt_common::core::{AddressType, AdvertisingSetId};
152 use bt_gatt::central::{AdvertisingDatum, PeerName};
153 use bt_gatt::test_utils::{FakeTypes, ScannedResultStream, ScannedResultStreamController};
154 use bt_gatt::types::Error as BtGattError;
155 use bt_gatt::types::GattError;
156
157 fn setup_stream() -> (EventStream<FakeTypes>, ScannedResultStreamController) {
158 let fake_scan_result_stream = ScannedResultStream::new();
159 let controller = fake_scan_result_stream.controller();
160 let broadcast_sources = DiscoveredBroadcastSources::new();
161 let broadcast_source_scan_started = Arc::new(AtomicBool::new(false));
162
163 (
164 EventStream::<FakeTypes>::new(
165 fake_scan_result_stream,
166 broadcast_sources,
167 broadcast_source_scan_started,
168 ),
169 controller,
170 )
171 }
172
173 #[test]
174 fn poll_found_broadcast_source_events() {
175 let (mut stream, scan_result_controller) = setup_stream();
176
177 let broadcast_source_pid = PeerId(1005);
179
180 scan_result_controller.add_scanned_result(Ok(ScanResult {
181 id: broadcast_source_pid,
182 connectable: true,
183 name: PeerName::Unknown,
184 advertised: vec![AdvertisingDatum::ServiceData(
185 BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE,
186 vec![0x01, 0x02, 0x03],
187 )],
188 advertising_sid: Some(0),
189 periodic_advertising_interval: None,
190 }));
191
192 let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
195 assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());
196
197 let _ = stream.broadcast_sources.merge_broadcast_source_data(
202 &broadcast_source_pid,
203 &BroadcastSource::default()
204 .with_address([1, 2, 3, 4, 5, 6])
205 .with_address_type(AddressType::Public)
206 .with_advertising_sid(AdvertisingSetId(1)),
207 );
208
209 #[rustfmt::skip]
212 let base_data = vec![
213 0x10, 0x20, 0x30, 0x02, 0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x02, 0x05,
222 0x08, ];
225
226 scan_result_controller.add_scanned_result(Ok(ScanResult {
227 id: broadcast_source_pid,
228 connectable: true,
229 name: PeerName::Unknown,
230 advertised: vec![AdvertisingDatum::ServiceData(
231 BASIC_AUDIO_ANNOUNCEMENT_SERVICE,
232 base_data.clone(),
233 )],
234 advertising_sid: Some(1),
235 periodic_advertising_interval: Some(0x0100),
236 }));
237
238 let Poll::Ready(Some(Ok(event))) = stream.poll_next_unpin(&mut noop_cx) else {
241 panic!("should have received event");
242 };
243 assert_matches!(event, Event::FoundBroadcastSource{peer, source} => {
244 assert_eq!(peer, broadcast_source_pid);
245 assert_eq!(source.advertising_sid, Some(AdvertisingSetId(1)));
246 assert_eq!(source.periodic_advertising_interval, Some(PeriodicAdvertisingInterval(0x0100)));
247 });
248
249 assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());
250
251 scan_result_controller.add_scanned_result(Ok(ScanResult {
253 id: broadcast_source_pid,
254 connectable: true,
255 name: PeerName::Unknown,
256 advertised: vec![AdvertisingDatum::ServiceData(
257 BASIC_AUDIO_ANNOUNCEMENT_SERVICE,
258 base_data.clone(),
259 )],
260 advertising_sid: Some(1),
261 periodic_advertising_interval: Some(0x0100),
262 }));
263
264 assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());
267 }
268
269 #[test]
270 fn central_scan_stream_terminates() {
271 let (mut stream, scan_result_controller) = setup_stream();
272
273 scan_result_controller.add_scanned_result(Err(BtGattError::Gatt(GattError::InvalidPdu)));
275
276 let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
277 match stream.poll_next_unpin(&mut noop_cx) {
278 Poll::Ready(Some(Err(e))) => assert_matches!(e, Error::CentralScanTerminated),
279 _ => panic!("should have received central scan terminated error"),
280 }
281
282 assert_matches!(stream.poll_next_unpin(&mut noop_cx), Poll::Ready(None));
284 assert_matches!(stream.is_terminated(), true);
285 }
286
287 #[test]
288 fn poll_processes_multiple_results_eagerly() {
289 let (mut stream, scan_result_controller) = setup_stream();
290 let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
291
292 let broadcast_source_pid = PeerId(1005);
293
294 #[rustfmt::skip]
295 let base_data = vec![
296 0x10, 0x20, 0x30, 0x01, 0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, ];
302
303 scan_result_controller.add_scanned_result(Ok(ScanResult {
305 id: PeerId(1),
306 connectable: true,
307 name: PeerName::Unknown,
308 advertised: vec![],
309 advertising_sid: Some(0),
310 periodic_advertising_interval: None,
311 }));
312 scan_result_controller.add_scanned_result(Ok(ScanResult {
314 id: broadcast_source_pid,
315 connectable: true,
316 name: PeerName::Unknown,
317 advertised: vec![AdvertisingDatum::ServiceData(
318 BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE,
319 vec![0x01, 0x02, 0x03],
320 )],
321 advertising_sid: Some(1),
322 periodic_advertising_interval: None,
323 }));
324 scan_result_controller.add_scanned_result(Ok(ScanResult {
326 id: broadcast_source_pid,
327 connectable: true,
328 name: PeerName::Unknown,
329 advertised: vec![AdvertisingDatum::ServiceData(
330 BASIC_AUDIO_ANNOUNCEMENT_SERVICE,
331 base_data.clone(),
332 )],
333 advertising_sid: Some(1),
334 periodic_advertising_interval: None,
335 }));
336
337 let poll_result = stream.poll_next_unpin(&mut noop_cx);
340 let Poll::Ready(Some(Ok(event))) = poll_result else {
341 panic!("should have received event, but got {:?}", poll_result);
342 };
343 assert_matches!(event, Event::FoundBroadcastSource{peer, ..} => {
344 assert_eq!(peer, broadcast_source_pid);
345 });
346
347 assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());
349 }
350}