1use crate::device::Info as DeviceInfo;
6use crate::sigproc::{Element, ElementState, Topology};
7use anyhow::{anyhow, Context, Error};
8use async_utils::event::Event as AsyncEvent;
9use async_utils::hanging_get::client::HangingGetStream;
10use fidl::endpoints::create_proxy;
11use fidl_fuchsia_audio_device as fadevice;
12use fuchsia_async::Task;
13use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use futures::lock::Mutex;
15use futures::StreamExt;
16use log::error;
17use std::collections::{BTreeMap, BTreeSet};
18use std::sync::Arc;
19use zx_status::Status;
20
/// A change to the set of audio devices tracked by a [`Registry`].
///
/// Events are delivered to every receiver obtained from [`Registry::subscribe`].
#[derive(Debug, Clone)]
pub enum DeviceEvent {
    /// A device was reported as added; carries the info for the new device.
    Added(Box<DeviceInfo>),
    /// The device with the given token ID was reported as removed.
    Removed(fadevice::TokenId),
}
28
/// Sending half of the unbounded channel used to deliver [`DeviceEvent`]s.
pub type DeviceEventSender = UnboundedSender<DeviceEvent>;
/// Receiving half of the unbounded channel used to deliver [`DeviceEvent`]s,
/// as returned by [`Registry::subscribe`].
pub type DeviceEventReceiver = UnboundedReceiver<DeviceEvent>;
31
/// Client-side view of the audio device registry.
///
/// A background task keeps `devices` up to date from the registry's
/// added/removed hanging gets and forwards each change to subscribers.
pub struct Registry {
    /// Connection to the registry; also used to create per-device observers.
    proxy: fadevice::RegistryProxy,
    /// Devices currently known, keyed by token ID. Written by the watcher task.
    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
    /// Signaled once the first `WatchDevicesAdded` response has been processed.
    devices_initialized: AsyncEvent,
    /// Senders for every channel handed out by `subscribe`.
    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
    /// Handle of the watcher task, held for as long as the `Registry` exists.
    // NOTE(review): presumably dropping the `Task` cancels the watcher — confirm
    // against the `fuchsia_async::Task` documentation.
    _watch_devices_task: Task<()>,
}
39
40impl Registry {
41 pub fn new(proxy: fadevice::RegistryProxy) -> Self {
42 let devices = Arc::new(Mutex::new(BTreeMap::new()));
43 let devices_initialized = AsyncEvent::new();
44 let event_senders = Arc::new(Mutex::new(vec![]));
45
46 let watch_devices_task = Task::spawn({
47 let proxy = proxy.clone();
48 let devices = devices.clone();
49 let devices_initialized = devices_initialized.clone();
50 let event_senders = event_senders.clone();
51 async {
52 if let Err(err) =
53 watch_devices(proxy, devices, devices_initialized, event_senders).await
54 {
55 error!(err:%; "Failed to watch Registry devices");
56 }
57 }
58 });
59
60 Self {
61 proxy,
62 devices,
63 devices_initialized,
64 event_senders,
65 _watch_devices_task: watch_devices_task,
66 }
67 }
68
69 pub async fn device_info(&self, token_id: fadevice::TokenId) -> Option<DeviceInfo> {
73 self.devices_initialized.wait().await;
74 self.devices.lock().await.get(&token_id).cloned()
75 }
76
77 pub async fn device_infos(&self) -> BTreeMap<fadevice::TokenId, DeviceInfo> {
79 self.devices_initialized.wait().await;
80 self.devices.lock().await.clone()
81 }
82
83 pub async fn observe(&self, token_id: fadevice::TokenId) -> Result<RegistryDevice, Error> {
87 self.devices_initialized.wait().await;
88
89 let info = self
90 .devices
91 .lock()
92 .await
93 .get(&token_id)
94 .cloned()
95 .ok_or_else(|| anyhow!("Device with ID {} does not exist", token_id))?;
96
97 let (observer_proxy, observer_server) = create_proxy::<fadevice::ObserverMarker>();
98
99 let _ = self
100 .proxy
101 .create_observer(fadevice::RegistryCreateObserverRequest {
102 token_id: Some(token_id),
103 observer_server: Some(observer_server),
104 ..Default::default()
105 })
106 .await
107 .context("Failed to call CreateObserver")?
108 .map_err(|err| anyhow!("failed to create device observer: {:?}", err))?;
109
110 Ok(RegistryDevice::new(info, observer_proxy))
111 }
112
113 pub async fn subscribe(&self) -> DeviceEventReceiver {
115 let (sender, receiver) = mpsc::unbounded::<DeviceEvent>();
116 self.event_senders.lock().await.push(sender);
117 receiver
118 }
119}
120
121async fn watch_devices(
127 proxy: fadevice::RegistryProxy,
128 devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
129 devices_initialized: AsyncEvent,
130 event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
131) -> Result<(), Error> {
132 let mut devices_initialized = Some(devices_initialized);
133
134 let mut devices_added_stream =
135 HangingGetStream::new(proxy.clone(), fadevice::RegistryProxy::watch_devices_added);
136 let mut device_removed_stream =
137 HangingGetStream::new(proxy, fadevice::RegistryProxy::watch_device_removed);
138
139 loop {
140 futures::select! {
141 added = devices_added_stream.select_next_some() => {
142 let response = added
143 .context("failed to call WatchDevicesAdded")?
144 .map_err(|err| anyhow!("failed to watch for added devices: {:?}", err))?;
145 let added_devices = response.devices.ok_or_else(|| anyhow!("missing devices"))?;
146
147 let mut devices = devices.lock().await;
148 let mut event_senders = event_senders.lock().await;
149
150 for new_device in added_devices.into_iter() {
151 let token_id = new_device.token_id.ok_or_else(|| anyhow!("device info missing token_id"))?;
152 let device_info = DeviceInfo::from(new_device);
153 for sender in event_senders.iter_mut() {
154 let _ = sender.unbounded_send(DeviceEvent::Added(Box::new(device_info.clone())));
155 }
156 let _ = devices.insert(token_id, device_info);
157 }
158
159 if let Some(devices_initialized) = devices_initialized.take() {
160 devices_initialized.signal();
161 }
162 },
163 removed = device_removed_stream.select_next_some() => {
164 let response = removed
165 .context("failed to call WatchDeviceRemoved")?
166 .map_err(|err| anyhow!("failed to watch for removed device: {:?}", err))?;
167 let token_id = response.token_id.ok_or_else(|| anyhow!("missing token_id"))?;
168 let mut devices = devices.lock().await;
169 let _ = devices.remove(&token_id);
170 for sender in event_senders.lock().await.iter_mut() {
171 let _ = sender.unbounded_send(DeviceEvent::Removed(token_id));
172 }
173 }
174 }
175 }
176}
177
/// A device being observed through the registry, as returned by
/// [`Registry::observe`].
pub struct RegistryDevice {
    /// Info the registry reported for this device.
    _info: DeviceInfo,
    /// Connection to the device's observer.
    _proxy: fadevice::ObserverProxy,

    /// Signal processing accessors; `None` unless the device info lists both
    /// signal processing elements and topologies.
    pub signal_processing: Option<SignalProcessing>,
}
185
186impl RegistryDevice {
187 pub fn new(info: DeviceInfo, proxy: fadevice::ObserverProxy) -> Self {
188 let is_signal_processing_supported = info.0.signal_processing_elements.is_some()
189 && info.0.signal_processing_topologies.is_some();
190 let signal_processing =
191 is_signal_processing_supported.then(|| SignalProcessing::new(proxy.clone()));
192
193 Self { _info: info, _proxy: proxy, signal_processing }
194 }
195}
196
/// Cached view of a device's signal processing state, kept current by two
/// background tasks that watch element states and the active topology.
pub struct SignalProcessing {
    /// Observer connection used for the `GetElements` / `GetTopologies` queries.
    proxy: fadevice::ObserverProxy,

    /// Most recent state per element; `None` until the first state is recorded.
    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
    /// Most recently reported topology ID; `None` until one is reported.
    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,

    /// Signaled when every element has reported a state, when the device
    /// reports `NOT_SUPPORTED`, or when the element state watcher fails.
    element_states_initialized: AsyncEvent,
    /// Signaled when the first topology ID arrives or the topology watcher fails.
    topology_id_initialized: AsyncEvent,

    /// Handles of the watcher tasks, held for as long as this value exists.
    // NOTE(review): presumably dropping a `Task` cancels it — confirm against
    // the `fuchsia_async::Task` documentation.
    _watch_element_states_task: Task<()>,
    _watch_topology_task: Task<()>,
}
210
211impl SignalProcessing {
212 fn new(proxy: fadevice::ObserverProxy) -> Self {
213 let element_states = Arc::new(Mutex::new(None));
214 let topology_id = Arc::new(Mutex::new(None));
215
216 let element_states_initialized = AsyncEvent::new();
217 let topology_id_initialized = AsyncEvent::new();
218
219 let watch_element_states_task = Task::spawn({
220 let proxy = proxy.clone();
221 let element_states = element_states.clone();
222 let element_states_initialized = element_states_initialized.clone();
223 async move {
224 if let Err(err) =
225 watch_element_states(proxy, element_states, element_states_initialized.clone())
226 .await
227 {
228 error!(err:%; "Failed to watch Registry element states");
229 element_states_initialized.signal();
233 }
234 }
235 });
236
237 let watch_topology_task = Task::spawn({
238 let proxy = proxy.clone();
239 let topology_id = topology_id.clone();
240 let topology_id_initialized = topology_id_initialized.clone();
241 async move {
242 if let Err(err) =
243 watch_topology(proxy, topology_id, topology_id_initialized.clone()).await
244 {
245 error!(err:%; "Failed to watch Registry topology");
246 topology_id_initialized.signal();
250 }
251 }
252 });
253
254 Self {
255 proxy,
256 element_states,
257 topology_id,
258 element_states_initialized,
259 topology_id_initialized,
260 _watch_element_states_task: watch_element_states_task,
261 _watch_topology_task: watch_topology_task,
262 }
263 }
264
265 pub async fn elements(&self) -> Result<Option<Vec<Element>>, Error> {
268 let response = self
269 .proxy
270 .get_elements()
271 .await
272 .context("failed to call GetElements")?
273 .map_err(|status| Status::from_raw(status));
274
275 if let Err(Status::NOT_SUPPORTED) = response {
276 return Ok(None);
277 }
278
279 let elements = response
280 .context("failed to get elements")?
281 .into_iter()
282 .map(TryInto::try_into)
283 .collect::<Result<Vec<_>, _>>()
284 .map_err(|err| anyhow!("Invalid element: {}", err))?;
285
286 Ok(Some(elements))
287 }
288
289 pub async fn topologies(&self) -> Result<Option<Vec<Topology>>, Error> {
292 let response = self
293 .proxy
294 .get_topologies()
295 .await
296 .context("failed to call GetTopologies")?
297 .map_err(|status| Status::from_raw(status));
298
299 if let Err(Status::NOT_SUPPORTED) = response {
300 return Ok(None);
301 }
302
303 let topologies = response
304 .context("failed to get topologies")?
305 .into_iter()
306 .map(TryInto::try_into)
307 .collect::<Result<Vec<_>, _>>()
308 .map_err(|err| anyhow!("Invalid topology: {}", err))?;
309
310 Ok(Some(topologies))
311 }
312
313 pub async fn topology_id(&self) -> Option<fadevice::TopologyId> {
316 self.topology_id_initialized.wait().await;
317 *self.topology_id.lock().await
318 }
319
320 pub async fn element_state(&self, element_id: fadevice::ElementId) -> Option<ElementState> {
325 self.element_states_initialized.wait().await;
326 self.element_states
327 .lock()
328 .await
329 .as_ref()
330 .and_then(|states| states.get(&element_id).cloned())
331 }
332
333 pub async fn element_states(&self) -> Option<BTreeMap<fadevice::ElementId, ElementState>> {
336 self.element_states_initialized.wait().await;
337 self.element_states.lock().await.clone()
338 }
339}
340
341async fn watch_element_states(
347 proxy: fadevice::ObserverProxy,
348 element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
349 element_states_initialized: AsyncEvent,
350) -> Result<(), Error> {
351 let mut element_states_initialized = Some(element_states_initialized);
352
353 let element_ids = {
354 let get_elements_response = proxy
355 .get_elements()
356 .await
357 .context("failed to call GetElements")?
358 .map_err(|status| Status::from_raw(status));
359
360 if let Err(Status::NOT_SUPPORTED) = get_elements_response {
361 element_states_initialized.take().unwrap().signal();
362 return Ok(());
363 }
364
365 get_elements_response
366 .context("failed to get elements")?
367 .into_iter()
368 .map(|element| element.id.ok_or_else(|| anyhow!("missing element 'id'")))
369 .collect::<Result<Vec<_>, _>>()?
370 };
371
372 let mut uninitialized_element_ids = BTreeSet::from_iter(element_ids.iter().copied());
374
375 let state_streams = element_ids.into_iter().map(|element_id| {
376 HangingGetStream::new(proxy.clone(), move |p| p.watch_element_state(element_id))
377 .map(move |element_state_result| (element_id, element_state_result))
378 });
379
380 let mut all_states_stream = futures::stream::select_all(state_streams);
381
382 while let Some((element_id, element_state_result)) = all_states_stream.next().await {
383 let element_state: ElementState = element_state_result
384 .context("failed to call WatchElementState")?
385 .try_into()
386 .map_err(|err| anyhow!("Invalid element state: {}", err))?;
387 let mut element_states = element_states.lock().await;
388 let element_states_map = element_states.get_or_insert_with(|| BTreeMap::new());
389 let _ = element_states_map.insert(element_id, element_state);
390
391 if element_states_initialized.is_some() {
393 let _ = uninitialized_element_ids.remove(&element_id);
394 if uninitialized_element_ids.is_empty() {
395 element_states_initialized.take().unwrap().signal();
396 }
397 }
398 }
399
400 Ok(())
401}
402
403async fn watch_topology(
409 proxy: fadevice::ObserverProxy,
410 topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,
411 topology_id_initialized: AsyncEvent,
412) -> Result<(), Error> {
413 let mut topology_id_initialized = Some(topology_id_initialized);
414
415 let mut topology_stream =
416 HangingGetStream::new(proxy.clone(), fadevice::ObserverProxy::watch_topology);
417
418 while let Some(topology_result) = topology_stream.next().await {
419 let new_topology_id = topology_result.context("failed to call WatchTopology")?;
420
421 *topology_id.lock().await = Some(new_topology_id);
422
423 if let Some(topology_id_initialized) = topology_id_initialized.take() {
424 topology_id_initialized.signal();
425 }
426 }
427
428 Ok(())
429}
430
#[cfg(test)]
mod test {
    use super::*;
    use async_utils::hanging_get::server::{HangingGet, Publisher};
    use fidl::endpoints::spawn_local_stream_handler;

    // Hanging-get plumbing for `WatchDevicesAdded`.
    type AddedResponse = fadevice::RegistryWatchDevicesAddedResponse;
    type AddedResponder = fadevice::RegistryWatchDevicesAddedResponder;
    type AddedNotifyFn = Box<dyn Fn(&AddedResponse, AddedResponder) -> bool>;
    type AddedPublisher = Publisher<AddedResponse, AddedResponder, AddedNotifyFn>;

    // Hanging-get plumbing for `WatchDeviceRemoved`.
    type RemovedResponse = fadevice::RegistryWatchDeviceRemovedResponse;
    type RemovedResponder = fadevice::RegistryWatchDeviceRemovedResponder;
    type RemovedNotifyFn = Box<dyn Fn(&RemovedResponse, RemovedResponder) -> bool>;
    type RemovedPublisher = Publisher<RemovedResponse, RemovedResponder, RemovedNotifyFn>;

    /// Builds a `WatchDevicesAdded` response carrying `devices`.
    fn added_response(devices: Vec<fadevice::Info>) -> fadevice::RegistryWatchDevicesAddedResponse {
        fadevice::RegistryWatchDevicesAddedResponse { devices: Some(devices), ..Default::default() }
    }

    /// Builds a `WatchDeviceRemoved` response for the device `token_id`.
    fn removed_response(
        token_id: fadevice::TokenId,
    ) -> fadevice::RegistryWatchDeviceRemovedResponse {
        fadevice::RegistryWatchDeviceRemovedResponse {
            token_id: Some(token_id),
            ..Default::default()
        }
    }

    /// Serves a fake registry that answers the two device-watching hanging gets.
    ///
    /// The first `WatchDevicesAdded` call is answered with `initial_devices`;
    /// later answers are driven through the returned publishers. The removed
    /// hanging get starts without a value.
    fn serve_registry(
        initial_devices: Vec<fadevice::Info>,
    ) -> (fadevice::RegistryProxy, AddedPublisher, RemovedPublisher) {
        let notify_added: AddedNotifyFn = Box::new(|response, responder: AddedResponder| {
            responder.send(Ok(response)).expect("failed to send response");
            true
        });
        let mut added_broker = HangingGet::new(added_response(initial_devices), notify_added);
        let added_publisher = added_broker.new_publisher();
        let added_subscriber = Arc::new(Mutex::new(added_broker.new_subscriber()));

        let notify_removed: RemovedNotifyFn =
            Box::new(|response, responder: RemovedResponder| {
                responder.send(Ok(response)).expect("failed to send response");
                true
            });
        let mut removed_broker = HangingGet::new_unknown_state(notify_removed);
        let removed_publisher = removed_broker.new_publisher();
        let removed_subscriber = Arc::new(Mutex::new(removed_broker.new_subscriber()));

        let proxy = spawn_local_stream_handler(move |request| {
            let added_subscriber = added_subscriber.clone();
            let removed_subscriber = removed_subscriber.clone();
            async move {
                match request {
                    fadevice::RegistryRequest::WatchDevicesAdded { responder } => {
                        added_subscriber.lock().await.register(responder).unwrap()
                    }
                    fadevice::RegistryRequest::WatchDeviceRemoved { responder } => {
                        removed_subscriber.lock().await.register(responder).unwrap()
                    }
                    _ => unimplemented!(),
                }
            }
        });

        (proxy, added_publisher, removed_publisher)
    }

    #[fuchsia::test]
    async fn test_device_info() {
        let (registry_proxy, _added_publisher, _removed_publisher) =
            serve_registry(vec![fadevice::Info { token_id: Some(1), ..Default::default() }]);
        let registry = Registry::new(registry_proxy);

        // Only the initially served device is known.
        assert!(registry.device_info(1).await.is_some());
        assert!(registry.device_info(2).await.is_none());
    }

    #[fuchsia::test]
    async fn test_subscribe() {
        let (registry_proxy, added_publisher, removed_publisher) = serve_registry(vec![]);
        let registry = Registry::new(registry_proxy);

        // Subscribe only after the (empty) initial device list was processed.
        registry.devices_initialized.wait().await;

        let mut events_receiver = registry.subscribe().await;

        // Publish two new devices; each should produce one `Added` event.
        added_publisher.set(added_response(vec![
            fadevice::Info { token_id: Some(1), ..Default::default() },
            fadevice::Info { token_id: Some(2), ..Default::default() },
        ]));

        let added_events: Vec<_> = events_receiver.by_ref().take(2).collect().await;

        let mut added_token_ids: Vec<_> = added_events
            .iter()
            .filter_map(|event| {
                if let DeviceEvent::Added(info) = event {
                    Some(info.token_id())
                } else {
                    None
                }
            })
            .collect();
        added_token_ids.sort();
        assert_eq!(added_token_ids, vec![1, 2]);

        // Remove one device; it should produce a single `Removed` event.
        removed_publisher.set(removed_response(2));

        let removed_events: Vec<_> = events_receiver.take(1).collect().await;

        let mut removed_token_ids: Vec<_> = removed_events
            .iter()
            .filter_map(|event| {
                if let DeviceEvent::Removed(token_id) = event {
                    Some(*token_id)
                } else {
                    None
                }
            })
            .collect();
        removed_token_ids.sort();
        assert_eq!(removed_token_ids, vec![2]);
    }
}