bt_broadcast_assistant/assistant/
event.rs1use core::pin::Pin;
6use futures::stream::{FusedStream, Stream, StreamExt};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::task::Poll;
10
11use bt_bap::types::{BroadcastAudioSourceEndpoint, BroadcastId};
12use bt_common::core::AdvertisingSetId;
13use bt_common::packet_encoding::Decodable;
14use bt_common::packet_encoding::Error as PacketError;
15use bt_common::PeerId;
16use bt_gatt::central::{AdvertisingDatum, ScanResult};
17
18use crate::assistant::{
19 DiscoveredBroadcastSources, Error, BASIC_AUDIO_ANNOUNCEMENT_SERVICE,
20 BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE,
21};
22use crate::types::BroadcastSource;
23
24#[derive(Debug)]
25pub enum Event {
26 FoundBroadcastSource { peer: PeerId, source: BroadcastSource },
27 CouldNotParseAdvertisingData { peer: PeerId, error: PacketError },
28}
29
30pub struct EventStream<T: bt_gatt::GattTypes> {
34 scan_result_stream: Pin<Box<<T as bt_gatt::GattTypes>::ScanResultStream>>,
35 terminated: bool,
36
37 broadcast_sources: Arc<DiscoveredBroadcastSources>,
38 broadcast_source_scan_started: Arc<AtomicBool>,
39}
40
41impl<T: bt_gatt::GattTypes> EventStream<T> {
42 pub(crate) fn new(
43 scan_result_stream: T::ScanResultStream,
44 broadcast_sources: Arc<DiscoveredBroadcastSources>,
45 broadcast_source_scan_started: Arc<AtomicBool>,
46 ) -> Self {
47 Self {
48 scan_result_stream: Box::pin(scan_result_stream),
49 terminated: false,
50 broadcast_sources,
51 broadcast_source_scan_started,
52 }
53 }
54
55 fn try_into_broadcast_source(
59 scan_result: &ScanResult,
60 ) -> Result<Option<BroadcastSource>, PacketError> {
61 let mut source = None;
62 for datum in &scan_result.advertised {
63 let AdvertisingDatum::ServiceData(uuid, data) = datum else {
64 continue;
65 };
66 if *uuid == BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE {
67 let bid = BroadcastId::decode(data.as_slice()).0?;
68 source.get_or_insert(BroadcastSource::default()).with_broadcast_id(bid);
69 } else if *uuid == BASIC_AUDIO_ANNOUNCEMENT_SERVICE {
70 let base = BroadcastAudioSourceEndpoint::decode(data.as_slice()).0?;
72 source.get_or_insert(BroadcastSource::default()).with_endpoint(base);
73 }
74 }
75 if let Some(src) = &mut source {
76 src.advertising_sid = Some(AdvertisingSetId(scan_result.advertising_sid));
77 }
78 Ok(source)
79 }
80}
81
82impl<T: bt_gatt::GattTypes> Drop for EventStream<T> {
83 fn drop(&mut self) {
84 self.broadcast_source_scan_started.store(false, Ordering::Relaxed);
85 }
86}
87
88impl<T: bt_gatt::GattTypes> FusedStream for EventStream<T> {
89 fn is_terminated(&self) -> bool {
90 self.terminated
91 }
92}
93
94impl<T: bt_gatt::GattTypes> Stream for EventStream<T> {
95 type Item = Result<Event, Error>;
96
97 fn poll_next(
98 mut self: std::pin::Pin<&mut Self>,
99 cx: &mut std::task::Context<'_>,
100 ) -> Poll<Option<Self::Item>> {
101 if self.terminated {
102 return Poll::Ready(None);
103 }
104
105 match futures::ready!(self.scan_result_stream.poll_next_unpin(cx)) {
108 Some(Ok(scanned)) => {
109 match Self::try_into_broadcast_source(&scanned) {
110 Err(e) => {
111 return Poll::Ready(Some(Ok(Event::CouldNotParseAdvertisingData {
112 peer: scanned.id,
113 error: e,
114 })));
115 }
116 Ok(Some(found_source)) => {
117 let (broadcast_source, changed) = self
120 .broadcast_sources
121 .merge_broadcast_source_data(&scanned.id, &found_source);
122
123 if broadcast_source.into_add_source() && changed {
126 return Poll::Ready(Some(Ok(Event::FoundBroadcastSource {
127 peer: scanned.id,
128 source: broadcast_source,
129 })));
130 }
131
132 Poll::Pending
133 }
134 Ok(None) => Poll::Pending,
135 }
136 }
137 None | Some(Err(_)) => {
138 self.terminated = true;
139 self.broadcast_source_scan_started.store(false, Ordering::Relaxed);
140 Poll::Ready(Some(Err(Error::CentralScanTerminated)))
141 }
142 }
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;

    use bt_common::core::{AddressType, AdvertisingSetId};
    use bt_gatt::central::{AdvertisingDatum, PeerName};
    use bt_gatt::test_utils::{FakeTypes, ScannedResultStream};
    use bt_gatt::types::Error as BtGattError;
    use bt_gatt::types::GattError;

    /// Builds an `EventStream` over a fake scan-result stream. Returns the event stream
    /// together with a clone of the fake so tests can inject scan results.
    fn setup_stream() -> (EventStream<FakeTypes>, ScannedResultStream) {
        let fake_scan_results = ScannedResultStream::new();
        let stream = EventStream::<FakeTypes>::new(
            fake_scan_results.clone(),
            DiscoveredBroadcastSources::new(),
            Arc::new(AtomicBool::new(false)),
        );
        (stream, fake_scan_results)
    }

    #[test]
    fn poll_found_broadcast_source_events() {
        let (mut stream, mut scan_result_stream) = setup_stream();
        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());

        let broadcast_source_pid = PeerId(1005);
        // Scan results from the same peer that differ only in advertised data and SID.
        let scanned_from_source = |advertised, advertising_sid| ScanResult {
            id: broadcast_source_pid,
            connectable: true,
            name: PeerName::Unknown,
            advertised,
            advertising_sid,
        };

        // Only a broadcast ID is advertised; the test expects no event yet.
        scan_result_stream.set_scanned_result(Ok(scanned_from_source(
            vec![AdvertisingDatum::ServiceData(
                BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE,
                vec![0x01, 0x02, 0x03],
            )],
            0,
        )));
        assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());

        // Seed the shared record directly with address details for the same peer.
        let _ = stream.broadcast_sources.merge_broadcast_source_data(
            &broadcast_source_pid,
            &BroadcastSource::default()
                .with_address([1, 2, 3, 4, 5, 6])
                .with_address_type(AddressType::Public)
                .with_advertising_sid(AdvertisingSetId(1)),
        );

        // Basic Audio Announcement service data, decoded as a `BroadcastAudioSourceEndpoint`.
        #[rustfmt::skip]
        let base_data = vec![
            0x10, 0x20, 0x30, 0x02,
            0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00,
            0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x02, 0x05, 0x08,
        ];
        let basic_audio_scan = || {
            scanned_from_source(
                vec![AdvertisingDatum::ServiceData(
                    BASIC_AUDIO_ANNOUNCEMENT_SERVICE,
                    base_data.clone(),
                )],
                1,
            )
        };

        // With the endpoint data in place, the source is reported.
        scan_result_stream.set_scanned_result(Ok(basic_audio_scan()));
        let Poll::Ready(Some(Ok(event))) = stream.poll_next_unpin(&mut noop_cx) else {
            panic!("should have received event");
        };
        assert_matches!(event, Event::FoundBroadcastSource { peer, source } => {
            assert_eq!(peer, broadcast_source_pid);
            assert_eq!(source.advertising_sid, Some(AdvertisingSetId(1)));
        });

        assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());

        // Re-advertising identical data is expected to produce no new event.
        scan_result_stream.set_scanned_result(Ok(basic_audio_scan()));
        assert!(stream.poll_next_unpin(&mut noop_cx).is_pending());
    }

    #[test]
    fn central_scan_stream_terminates() {
        let (mut stream, mut scan_result_stream) = setup_stream();
        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());

        // An error from the central scan ends the event stream.
        scan_result_stream.set_scanned_result(Err(BtGattError::Gatt(GattError::InvalidPdu)));

        let Poll::Ready(Some(Err(error))) = stream.poll_next_unpin(&mut noop_cx) else {
            panic!("should have received central scan terminated error");
        };
        assert_matches!(error, Error::CentralScanTerminated);

        // After the terminal error the stream is fused.
        assert_matches!(stream.poll_next_unpin(&mut noop_cx), Poll::Ready(None));
        assert!(stream.is_terminated());
    }
}