Skip to main content

fuchsia_audio/
registry.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::device::Info as DeviceInfo;
6use crate::sigproc::{Element, ElementState, Topology};
7use anyhow::{Context, Error, anyhow};
8use async_utils::event::Event as AsyncEvent;
9use async_utils::hanging_get::client::HangingGetStream;
10use flex_client::ProxyHasDomain;
11use flex_fuchsia_audio_device as fadevice;
12use fuchsia_async::Task;
13use futures::StreamExt;
14use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
15use futures::lock::Mutex;
16use log::{error, warn};
17use std::collections::{BTreeMap, BTreeSet};
18use std::sync::Arc;
19use zx_status::Status;
20
21#[derive(Debug, Clone)]
22pub enum DeviceEvent {
23    /// A device was added to the registry.
24    Added(Box<DeviceInfo>),
25    /// A device was removed from the registry.
26    Removed(fadevice::TokenId),
27}
28
29pub type DeviceEventSender = UnboundedSender<DeviceEvent>;
30pub type DeviceEventReceiver = UnboundedReceiver<DeviceEvent>;
31
32pub struct Registry {
33    proxy: fadevice::RegistryProxy,
34    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
35    devices_initialized: AsyncEvent,
36    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
37    _watch_devices_task: Task<()>,
38}
39
40impl Registry {
41    pub fn new(proxy: fadevice::RegistryProxy) -> Self {
42        let devices = Arc::new(Mutex::new(BTreeMap::new()));
43        let devices_initialized = AsyncEvent::new();
44        let event_senders = Arc::new(Mutex::new(vec![]));
45
46        let watch_devices_task = Task::spawn({
47            let proxy = proxy.clone();
48            let devices = devices.clone();
49            let devices_initialized = devices_initialized.clone();
50            let event_senders = event_senders.clone();
51            async move {
52                // Pass a clone to watch_devices so we can signal the original on error.
53                if let Err(err) =
54                    watch_devices(proxy, devices, devices_initialized.clone(), event_senders).await
55                {
56                    warn!(err:%; "Failed to watch Registry devices");
57                    // Ensure we don't hang waiters if the watch task fails.
58                    devices_initialized.signal();
59                }
60            }
61        });
62
63        Self {
64            proxy,
65            devices,
66            devices_initialized,
67            event_senders,
68            _watch_devices_task: watch_devices_task,
69        }
70    }
71
72    /// Returns information about the device with the given `token_id`.
73    ///
74    /// Returns None if there is no device with the given ID.
75    pub async fn device_info(&self, token_id: fadevice::TokenId) -> Option<DeviceInfo> {
76        self.devices_initialized.wait().await;
77        self.devices.lock().await.get(&token_id).cloned()
78    }
79
80    /// Returns information about all devices in the registry.
81    pub async fn device_infos(&self) -> BTreeMap<fadevice::TokenId, DeviceInfo> {
82        self.devices_initialized.wait().await;
83        self.devices.lock().await.clone()
84    }
85
86    /// Returns a [RegistryDevice] that observes the device with the given `token_id`.
87    ///
88    /// Returns an error if there is no device with the given token ID.
89    pub async fn observe(&self, token_id: fadevice::TokenId) -> Result<RegistryDevice, Error> {
90        self.devices_initialized.wait().await;
91
92        let info = self
93            .devices
94            .lock()
95            .await
96            .get(&token_id)
97            .cloned()
98            .ok_or_else(|| anyhow!("Device with ID {} does not exist", token_id))?;
99
100        let (observer_proxy, observer_server) =
101            self.proxy.domain().create_proxy::<fadevice::ObserverMarker>();
102
103        let _ = self
104            .proxy
105            .create_observer(fadevice::RegistryCreateObserverRequest {
106                token_id: Some(token_id),
107                observer_server: Some(observer_server),
108                ..Default::default()
109            })
110            .await
111            .context("Failed to call CreateObserver")?
112            .map_err(|err| anyhow!("failed to create device observer: {:?}", err))?;
113
114        Ok(RegistryDevice::new(info, observer_proxy))
115    }
116
117    /// Returns a channel of device events.
118    pub async fn subscribe(&self) -> DeviceEventReceiver {
119        let (sender, receiver) = mpsc::unbounded::<DeviceEvent>();
120        self.event_senders.lock().await.push(sender);
121        receiver
122    }
123}
124
125/// Watches devices added to and removed from the registry and updates
126/// `devices` with the current state.
127///
128/// Signals `devices_initialized` when `devices` is populated with the initial
129/// set of devices.
130async fn watch_devices(
131    proxy: fadevice::RegistryProxy,
132    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
133    devices_initialized: AsyncEvent,
134    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
135) -> Result<(), Error> {
136    let mut devices_initialized = Some(devices_initialized);
137
138    let mut devices_added_stream =
139        HangingGetStream::new(proxy.clone(), fadevice::RegistryProxy::watch_devices_added);
140    let mut device_removed_stream =
141        HangingGetStream::new(proxy, fadevice::RegistryProxy::watch_device_removed);
142
143    loop {
144        futures::select! {
145            added = devices_added_stream.select_next_some() => {
146                let response = added
147                    .context("failed to call WatchDevicesAdded")?
148                    .map_err(|err| anyhow!("failed to watch for added devices: {:?}", err))?;
149                let added_devices = response.devices.ok_or_else(|| anyhow!("missing devices"))?;
150
151                let mut devices = devices.lock().await;
152                let mut event_senders = event_senders.lock().await;
153
154                for new_device in added_devices.into_iter() {
155                    let token_id = new_device.token_id.ok_or_else(|| anyhow!("device info missing token_id"))?;
156                    let device_info = DeviceInfo::from(new_device);
157                    for sender in event_senders.iter_mut() {
158                        let _ = sender.unbounded_send(DeviceEvent::Added(Box::new(device_info.clone())));
159                    }
160                    let _ = devices.insert(token_id, device_info);
161                }
162
163                if let Some(devices_initialized) = devices_initialized.take() {
164                    devices_initialized.signal();
165                }
166            },
167            removed = device_removed_stream.select_next_some() => {
168                let response = removed
169                    .context("failed to call WatchDeviceRemoved")?
170                    .map_err(|err| anyhow!("failed to watch for removed device: {:?}", err))?;
171                let token_id = response.token_id.ok_or_else(|| anyhow!("missing token_id"))?;
172                let mut devices = devices.lock().await;
173                let _ = devices.remove(&token_id);
174                for sender in event_senders.lock().await.iter_mut() {
175                    let _ = sender.unbounded_send(DeviceEvent::Removed(token_id));
176                }
177            }
178        }
179    }
180}
181
182pub struct RegistryDevice {
183    _info: DeviceInfo,
184    _proxy: fadevice::ObserverProxy,
185
186    /// If None, this device does not support signal processing.
187    pub signal_processing: Option<SignalProcessing>,
188}
189
190impl RegistryDevice {
191    pub fn new(info: DeviceInfo, proxy: fadevice::ObserverProxy) -> Self {
192        let is_signal_processing_supported = info.0.signal_processing_elements.is_some()
193            && info.0.signal_processing_topologies.is_some();
194        let signal_processing =
195            is_signal_processing_supported.then(|| SignalProcessing::new(proxy.clone()));
196
197        Self { _info: info, _proxy: proxy, signal_processing }
198    }
199}
200
201/// Client for the composed signal processing `Reader` in a `fuchsia.audio.device.Observer`.
202pub struct SignalProcessing {
203    proxy: fadevice::ObserverProxy,
204
205    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
206    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,
207
208    element_states_initialized: AsyncEvent,
209    topology_id_initialized: AsyncEvent,
210
211    _watch_element_states_task: Task<()>,
212    _watch_topology_task: Task<()>,
213}
214
215impl SignalProcessing {
216    fn new(proxy: fadevice::ObserverProxy) -> Self {
217        let element_states = Arc::new(Mutex::new(None));
218        let topology_id = Arc::new(Mutex::new(None));
219
220        let element_states_initialized = AsyncEvent::new();
221        let topology_id_initialized = AsyncEvent::new();
222
223        let watch_element_states_task = Task::spawn({
224            let proxy = proxy.clone();
225            let element_states = element_states.clone();
226            let element_states_initialized = element_states_initialized.clone();
227            async move {
228                if let Err(err) =
229                    watch_element_states(proxy, element_states, element_states_initialized.clone())
230                        .await
231                {
232                    error!(err:%; "Failed to watch Registry element states");
233                    // Watching the element states will fail if the device does not support signal
234                    // processing. In this case, mark the states as initialized so the getter can
235                    // return the initial None value.
236                    element_states_initialized.signal();
237                }
238            }
239        });
240
241        let watch_topology_task = Task::spawn({
242            let proxy = proxy.clone();
243            let topology_id = topology_id.clone();
244            let topology_id_initialized = topology_id_initialized.clone();
245            async move {
246                if let Err(err) =
247                    watch_topology(proxy, topology_id, topology_id_initialized.clone()).await
248                {
249                    error!(err:%; "Failed to watch Registry topology");
250                    // Watching the topology ID will fail if the device does not support signal
251                    // processing. In this case, mark the ID as initialized so the getter can
252                    // return the initial None value.
253                    topology_id_initialized.signal();
254                }
255            }
256        });
257
258        Self {
259            proxy,
260            element_states,
261            topology_id,
262            element_states_initialized,
263            topology_id_initialized,
264            _watch_element_states_task: watch_element_states_task,
265            _watch_topology_task: watch_topology_task,
266        }
267    }
268
269    /// Returns this device's signal processing elements, or `None` if the device does not support
270    /// signal processing.
271    pub async fn elements(&self) -> Result<Option<Vec<Element>>, Error> {
272        let response = self
273            .proxy
274            .get_elements()
275            .await
276            .context("failed to call GetElements")?
277            .map_err(Status::from_raw);
278
279        if let Err(Status::NOT_SUPPORTED) = response {
280            return Ok(None);
281        }
282
283        let elements = response
284            .context("failed to get elements")?
285            .into_iter()
286            .map(TryInto::try_into)
287            .collect::<Result<Vec<_>, _>>()
288            .map_err(|err| anyhow!("Invalid element: {}", err))?;
289
290        Ok(Some(elements))
291    }
292
293    /// Returns this device's signal processing topologies, or `None` if the device does not
294    /// support signal processing.
295    pub async fn topologies(&self) -> Result<Option<Vec<Topology>>, Error> {
296        let response = self
297            .proxy
298            .get_topologies()
299            .await
300            .context("failed to call GetTopologies")?
301            .map_err(Status::from_raw);
302
303        if let Err(Status::NOT_SUPPORTED) = response {
304            return Ok(None);
305        }
306
307        let topologies = response
308            .context("failed to get topologies")?
309            .into_iter()
310            .map(TryInto::try_into)
311            .collect::<Result<Vec<_>, _>>()
312            .map_err(|err| anyhow!("Invalid topology: {}", err))?;
313
314        Ok(Some(topologies))
315    }
316
317    /// Returns the current signal processing topology ID, or `None` if the device does not support
318    /// signal processing.
319    pub async fn topology_id(&self) -> Option<fadevice::TopologyId> {
320        self.topology_id_initialized.wait().await;
321        *self.topology_id.lock().await
322    }
323
324    /// Returns the state of the signal processing element with the given `element_id`.
325    ///
326    /// Returns None if there is no element with the given ID, or if the device does not support
327    /// signal processing.
328    pub async fn element_state(&self, element_id: fadevice::ElementId) -> Option<ElementState> {
329        self.element_states_initialized.wait().await;
330        self.element_states
331            .lock()
332            .await
333            .as_ref()
334            .and_then(|states| states.get(&element_id).cloned())
335    }
336
337    /// Returns states of all signal processing elements, or `None` if the device does not support
338    /// signal processing.
339    pub async fn element_states(&self) -> Option<BTreeMap<fadevice::ElementId, ElementState>> {
340        self.element_states_initialized.wait().await;
341        self.element_states.lock().await.clone()
342    }
343}
344
345/// Watches element state changes on a registry device and updates `element_states` with the
346/// current state for each element.
347///
348/// Signals `element_states_initialized` when `element_states` is populated
349/// with the initial set of states.
350async fn watch_element_states(
351    proxy: fadevice::ObserverProxy,
352    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
353    element_states_initialized: AsyncEvent,
354) -> Result<(), Error> {
355    let mut element_states_initialized = Some(element_states_initialized);
356
357    let element_ids = {
358        let get_elements_response = proxy
359            .get_elements()
360            .await
361            .context("failed to call GetElements")?
362            .map_err(Status::from_raw);
363
364        if let Err(Status::NOT_SUPPORTED) = get_elements_response {
365            element_states_initialized.take().unwrap().signal();
366            return Ok(());
367        }
368
369        get_elements_response
370            .context("failed to get elements")?
371            .into_iter()
372            .map(|element| element.id.ok_or_else(|| anyhow!("missing element 'id'")))
373            .collect::<Result<Vec<_>, _>>()?
374    };
375
376    // Contains element IDs for which we haven't received an initial state.
377    let mut uninitialized_element_ids = BTreeSet::from_iter(element_ids.iter().copied());
378
379    let state_streams = element_ids.into_iter().map(|element_id| {
380        HangingGetStream::new(proxy.clone(), move |p| p.watch_element_state(element_id))
381            .map(move |element_state_result| (element_id, element_state_result))
382    });
383
384    let mut all_states_stream = futures::stream::select_all(state_streams);
385
386    while let Some((element_id, element_state_result)) = all_states_stream.next().await {
387        let element_state: ElementState = element_state_result
388            .context("failed to call WatchElementState")?
389            .try_into()
390            .map_err(|err| anyhow!("Invalid element state: {}", err))?;
391        let mut element_states = element_states.lock().await;
392        let element_states_map = element_states.get_or_insert_with(BTreeMap::new);
393        let _ = element_states_map.insert(element_id, element_state);
394
395        // Signal `element_states_initialized` once all elements have initial states.
396        if element_states_initialized.is_some() {
397            let _ = uninitialized_element_ids.remove(&element_id);
398            if uninitialized_element_ids.is_empty() {
399                element_states_initialized.take().unwrap().signal();
400            }
401        }
402    }
403
404    Ok(())
405}
406
407/// Watches topology changes on a registry device and updates `topology_id`
408/// when the topology changes.
409///
410/// Signals `topology_id_initialized` when `topology_id` is populated
411/// with the initial topology.
412async fn watch_topology(
413    proxy: fadevice::ObserverProxy,
414    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,
415    topology_id_initialized: AsyncEvent,
416) -> Result<(), Error> {
417    let mut topology_id_initialized = Some(topology_id_initialized);
418
419    let mut topology_stream =
420        HangingGetStream::new(proxy.clone(), fadevice::ObserverProxy::watch_topology);
421
422    while let Some(topology_result) = topology_stream.next().await {
423        let new_topology_id = topology_result.context("failed to call WatchTopology")?;
424
425        *topology_id.lock().await = Some(new_topology_id);
426
427        if let Some(topology_id_initialized) = topology_id_initialized.take() {
428            topology_id_initialized.signal();
429        }
430    }
431
432    Ok(())
433}
434
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use super::*;
    use async_utils::hanging_get::server::{HangingGet, Publisher};
    #[cfg(feature = "fdomain")]
    use fidl_test_util::fdomain::spawn_local_stream_handler;
    #[cfg(not(feature = "fdomain"))]
    use fidl_test_util::spawn_local_stream_handler;

    // Hanging-get plumbing for `WatchDevicesAdded`.
    type AddedResponse = fadevice::RegistryWatchDevicesAddedResponse;
    type AddedResponder = fadevice::RegistryWatchDevicesAddedResponder;
    type AddedNotifyFn = Box<dyn Fn(&AddedResponse, AddedResponder) -> bool>;
    type AddedPublisher = Publisher<AddedResponse, AddedResponder, AddedNotifyFn>;

    // Hanging-get plumbing for `WatchDeviceRemoved`.
    type RemovedResponse = fadevice::RegistryWatchDeviceRemovedResponse;
    type RemovedResponder = fadevice::RegistryWatchDeviceRemovedResponder;
    type RemovedNotifyFn = Box<dyn Fn(&RemovedResponse, RemovedResponder) -> bool>;
    type RemovedPublisher = Publisher<RemovedResponse, RemovedResponder, RemovedNotifyFn>;

    /// Serves a fake registry whose first `WatchDevicesAdded` response is `initial_devices`.
    ///
    /// Returns the client proxy plus publishers that let a test push further "added" and
    /// "removed" responses. Any other registry request hits `unimplemented!()`.
    fn serve_registry(
        initial_devices: Vec<fadevice::Info>,
    ) -> (fadevice::RegistryProxy, AddedPublisher, RemovedPublisher) {
        #[cfg(feature = "fdomain")]
        let client = fdomain_local::local_client_empty();

        let notify_added: AddedNotifyFn = Box::new(|response, responder: AddedResponder| {
            responder.send(Ok(response)).expect("failed to send response");
            true
        });
        let mut added_broker = HangingGet::new(added_response(initial_devices), notify_added);
        let added_publisher = added_broker.new_publisher();

        let notify_removed: RemovedNotifyFn =
            Box::new(|response, responder: RemovedResponder| {
                responder.send(Ok(response)).expect("failed to send response");
                true
            });
        // No removal has happened yet, so this broker starts without a state.
        let mut removed_broker = HangingGet::new_unknown_state(notify_removed);
        let removed_publisher = removed_broker.new_publisher();

        let added_subscriber = Rc::new(Mutex::new(added_broker.new_subscriber()));
        let removed_subscriber = Rc::new(Mutex::new(removed_broker.new_subscriber()));

        let proxy = spawn_local_stream_handler(
            #[cfg(feature = "fdomain")]
            &client,
            move |request| {
                let added = added_subscriber.clone();
                let removed = removed_subscriber.clone();
                async move {
                    match request {
                        fadevice::RegistryRequest::WatchDevicesAdded { responder } => {
                            added.lock().await.register(responder).unwrap()
                        }
                        fadevice::RegistryRequest::WatchDeviceRemoved { responder } => {
                            removed.lock().await.register(responder).unwrap()
                        }
                        _ => unimplemented!(),
                    }
                }
            },
        );

        (proxy, added_publisher, removed_publisher)
    }

    /// Builds a `WatchDevicesAdded` response carrying `devices`.
    fn added_response(devices: Vec<fadevice::Info>) -> fadevice::RegistryWatchDevicesAddedResponse {
        fadevice::RegistryWatchDevicesAddedResponse { devices: Some(devices), ..Default::default() }
    }

    /// Builds a `WatchDeviceRemoved` response for `token_id`.
    fn removed_response(
        token_id: fadevice::TokenId,
    ) -> fadevice::RegistryWatchDeviceRemovedResponse {
        fadevice::RegistryWatchDeviceRemovedResponse {
            token_id: Some(token_id),
            ..Default::default()
        }
    }

    #[fuchsia::test]
    async fn test_device_info() {
        let (registry_proxy, _added_publisher, _removed_publisher) =
            serve_registry(vec![fadevice::Info { token_id: Some(1), ..Default::default() }]);
        let registry = Registry::new(registry_proxy);

        // Device 1 was in the initial set; device 2 was never reported.
        assert!(registry.device_info(1).await.is_some());
        assert!(registry.device_info(2).await.is_none());
    }

    #[fuchsia::test]
    async fn test_subscribe() {
        let (registry_proxy, added_publisher, removed_publisher) = serve_registry(Vec::new());
        let registry = Registry::new(registry_proxy);

        // Wait for the (empty) initial device set before subscribing.
        registry.devices_initialized.wait().await;

        let mut events_receiver = registry.subscribe().await;

        // Publish a WatchDevicesAdded response with two devices and verify that we receive it.
        added_publisher.set(added_response(vec![
            fadevice::Info { token_id: Some(1), ..Default::default() },
            fadevice::Info { token_id: Some(2), ..Default::default() },
        ]));

        // There should be two events, one for each device.
        let added_events: Vec<_> = events_receiver.by_ref().take(2).collect().await;

        let mut added_token_ids = Vec::new();
        for event in &added_events {
            if let DeviceEvent::Added(info) = event {
                added_token_ids.push(info.token_id());
            }
        }
        added_token_ids.sort();
        assert_eq!(added_token_ids, vec![1, 2]);

        // Publish a WatchDeviceRemoved response and verify that we receive it.
        removed_publisher.set(removed_response(2));

        // There should be one event.
        let removed_events: Vec<_> = events_receiver.take(1).collect().await;

        let mut removed_token_ids = Vec::new();
        for event in &removed_events {
            if let DeviceEvent::Removed(token_id) = event {
                removed_token_ids.push(*token_id);
            }
        }
        removed_token_ids.sort();
        assert_eq!(removed_token_ids, vec![2]);
    }
}