fuchsia_audio/
registry.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::device::Info as DeviceInfo;
6use crate::sigproc::{Element, ElementState, Topology};
7use anyhow::{Context, Error, anyhow};
8use async_utils::event::Event as AsyncEvent;
9use async_utils::hanging_get::client::HangingGetStream;
10use fidl::endpoints::create_proxy;
11use fidl_fuchsia_audio_device as fadevice;
12use fuchsia_async::Task;
13use futures::StreamExt;
14use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
15use futures::lock::Mutex;
16use log::{error, warn};
17use std::collections::{BTreeMap, BTreeSet};
18use std::sync::Arc;
19use zx_status::Status;
20
21#[derive(Debug, Clone)]
22pub enum DeviceEvent {
23    /// A device was added to the registry.
24    Added(Box<DeviceInfo>),
25    /// A device was removed from the registry.
26    Removed(fadevice::TokenId),
27}
28
29pub type DeviceEventSender = UnboundedSender<DeviceEvent>;
30pub type DeviceEventReceiver = UnboundedReceiver<DeviceEvent>;
31
/// Client for `fuchsia.audio.device.Registry` that keeps a local cache of known devices.
///
/// A background task (see `watch_devices`) populates `devices` and forwards add/remove
/// notifications to every channel returned by [`Registry::subscribe`].
pub struct Registry {
    /// Connection to the registry server; also used by `observe` to call `CreateObserver`.
    proxy: fadevice::RegistryProxy,
    /// Devices currently known to the registry, keyed by token ID.
    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
    /// Signaled once `devices` holds the initial device set, or once the watch task fails.
    devices_initialized: AsyncEvent,
    /// One sender per `subscribe` call; the watch task sends device events through these.
    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
    /// Handle to the spawned watch task. It is stored rather than detached, presumably so the
    /// task stops when the `Registry` is dropped (confirm against `fuchsia_async::Task` docs).
    _watch_devices_task: Task<()>,
}
39
40impl Registry {
41    pub fn new(proxy: fadevice::RegistryProxy) -> Self {
42        let devices = Arc::new(Mutex::new(BTreeMap::new()));
43        let devices_initialized = AsyncEvent::new();
44        let event_senders = Arc::new(Mutex::new(vec![]));
45
46        let watch_devices_task = Task::spawn({
47            let proxy = proxy.clone();
48            let devices = devices.clone();
49            let devices_initialized = devices_initialized.clone();
50            let event_senders = event_senders.clone();
51            async move {
52                // Pass a clone to watch_devices so we can signal the original on error.
53                if let Err(err) =
54                    watch_devices(proxy, devices, devices_initialized.clone(), event_senders).await
55                {
56                    warn!(err:%; "Failed to watch Registry devices");
57                    // Ensure we don't hang waiters if the watch task fails.
58                    devices_initialized.signal();
59                }
60            }
61        });
62
63        Self {
64            proxy,
65            devices,
66            devices_initialized,
67            event_senders,
68            _watch_devices_task: watch_devices_task,
69        }
70    }
71
72    /// Returns information about the device with the given `token_id`.
73    ///
74    /// Returns None if there is no device with the given ID.
75    pub async fn device_info(&self, token_id: fadevice::TokenId) -> Option<DeviceInfo> {
76        self.devices_initialized.wait().await;
77        self.devices.lock().await.get(&token_id).cloned()
78    }
79
80    /// Returns information about all devices in the registry.
81    pub async fn device_infos(&self) -> BTreeMap<fadevice::TokenId, DeviceInfo> {
82        self.devices_initialized.wait().await;
83        self.devices.lock().await.clone()
84    }
85
86    /// Returns a [RegistryDevice] that observes the device with the given `token_id`.
87    ///
88    /// Returns an error if there is no device with the given token ID.
89    pub async fn observe(&self, token_id: fadevice::TokenId) -> Result<RegistryDevice, Error> {
90        self.devices_initialized.wait().await;
91
92        let info = self
93            .devices
94            .lock()
95            .await
96            .get(&token_id)
97            .cloned()
98            .ok_or_else(|| anyhow!("Device with ID {} does not exist", token_id))?;
99
100        let (observer_proxy, observer_server) = create_proxy::<fadevice::ObserverMarker>();
101
102        let _ = self
103            .proxy
104            .create_observer(fadevice::RegistryCreateObserverRequest {
105                token_id: Some(token_id),
106                observer_server: Some(observer_server),
107                ..Default::default()
108            })
109            .await
110            .context("Failed to call CreateObserver")?
111            .map_err(|err| anyhow!("failed to create device observer: {:?}", err))?;
112
113        Ok(RegistryDevice::new(info, observer_proxy))
114    }
115
116    /// Returns a channel of device events.
117    pub async fn subscribe(&self) -> DeviceEventReceiver {
118        let (sender, receiver) = mpsc::unbounded::<DeviceEvent>();
119        self.event_senders.lock().await.push(sender);
120        receiver
121    }
122}
123
124/// Watches devices added to and removed from the registry and updates
125/// `devices` with the current state.
126///
127/// Signals `devices_initialized` when `devices` is populated with the initial
128/// set of devices.
129async fn watch_devices(
130    proxy: fadevice::RegistryProxy,
131    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
132    devices_initialized: AsyncEvent,
133    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
134) -> Result<(), Error> {
135    let mut devices_initialized = Some(devices_initialized);
136
137    let mut devices_added_stream =
138        HangingGetStream::new(proxy.clone(), fadevice::RegistryProxy::watch_devices_added);
139    let mut device_removed_stream =
140        HangingGetStream::new(proxy, fadevice::RegistryProxy::watch_device_removed);
141
142    loop {
143        futures::select! {
144            added = devices_added_stream.select_next_some() => {
145                let response = added
146                    .context("failed to call WatchDevicesAdded")?
147                    .map_err(|err| anyhow!("failed to watch for added devices: {:?}", err))?;
148                let added_devices = response.devices.ok_or_else(|| anyhow!("missing devices"))?;
149
150                let mut devices = devices.lock().await;
151                let mut event_senders = event_senders.lock().await;
152
153                for new_device in added_devices.into_iter() {
154                    let token_id = new_device.token_id.ok_or_else(|| anyhow!("device info missing token_id"))?;
155                    let device_info = DeviceInfo::from(new_device);
156                    for sender in event_senders.iter_mut() {
157                        let _ = sender.unbounded_send(DeviceEvent::Added(Box::new(device_info.clone())));
158                    }
159                    let _ = devices.insert(token_id, device_info);
160                }
161
162                if let Some(devices_initialized) = devices_initialized.take() {
163                    devices_initialized.signal();
164                }
165            },
166            removed = device_removed_stream.select_next_some() => {
167                let response = removed
168                    .context("failed to call WatchDeviceRemoved")?
169                    .map_err(|err| anyhow!("failed to watch for removed device: {:?}", err))?;
170                let token_id = response.token_id.ok_or_else(|| anyhow!("missing token_id"))?;
171                let mut devices = devices.lock().await;
172                let _ = devices.remove(&token_id);
173                for sender in event_senders.lock().await.iter_mut() {
174                    let _ = sender.unbounded_send(DeviceEvent::Removed(token_id));
175                }
176            }
177        }
178    }
179}
180
/// A registry device together with the `fuchsia.audio.device.Observer` connection created for
/// it (see [`Registry::observe`]).
pub struct RegistryDevice {
    /// Device information as it was when this value was constructed.
    _info: DeviceInfo,
    /// Observer connection for the device; held so it stays open as long as this struct.
    _proxy: fadevice::ObserverProxy,

    /// If None, this device does not support signal processing: its info did not include both
    /// `signal_processing_elements` and `signal_processing_topologies`.
    pub signal_processing: Option<SignalProcessing>,
}
188
189impl RegistryDevice {
190    pub fn new(info: DeviceInfo, proxy: fadevice::ObserverProxy) -> Self {
191        let is_signal_processing_supported = info.0.signal_processing_elements.is_some()
192            && info.0.signal_processing_topologies.is_some();
193        let signal_processing =
194            is_signal_processing_supported.then(|| SignalProcessing::new(proxy.clone()));
195
196        Self { _info: info, _proxy: proxy, signal_processing }
197    }
198}
199
/// Client for the composed signal processing `Reader` in a `fuchsia.audio.device.Observer`.
///
/// Element states and the current topology ID are kept up to date by two background tasks
/// that run hanging-get watches on the observer.
pub struct SignalProcessing {
    /// Observer connection used for `GetElements`/`GetTopologies` and the watch calls.
    proxy: fadevice::ObserverProxy,

    /// Latest state per element ID; `None` until the first state arrives, and left `None`
    /// when `GetElements` reports `NOT_SUPPORTED`.
    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
    /// Latest topology ID reported by `WatchTopology`; `None` until one has been received.
    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,

    /// Signaled once every element has an initial state, once signal processing is known to
    /// be unsupported, or once the element-state watch fails.
    element_states_initialized: AsyncEvent,
    /// Signaled once the first topology ID is received, or once the topology watch fails.
    topology_id_initialized: AsyncEvent,

    /// Handles to the background watch tasks, held so they are tied to this struct's
    /// lifetime (presumably dropping a `Task` stops it — confirm).
    _watch_element_states_task: Task<()>,
    _watch_topology_task: Task<()>,
}
213
214impl SignalProcessing {
215    fn new(proxy: fadevice::ObserverProxy) -> Self {
216        let element_states = Arc::new(Mutex::new(None));
217        let topology_id = Arc::new(Mutex::new(None));
218
219        let element_states_initialized = AsyncEvent::new();
220        let topology_id_initialized = AsyncEvent::new();
221
222        let watch_element_states_task = Task::spawn({
223            let proxy = proxy.clone();
224            let element_states = element_states.clone();
225            let element_states_initialized = element_states_initialized.clone();
226            async move {
227                if let Err(err) =
228                    watch_element_states(proxy, element_states, element_states_initialized.clone())
229                        .await
230                {
231                    error!(err:%; "Failed to watch Registry element states");
232                    // Watching the element states will fail if the device does not support signal
233                    // processing. In this case, mark the states as initialized so the getter can
234                    // return the initial None value.
235                    element_states_initialized.signal();
236                }
237            }
238        });
239
240        let watch_topology_task = Task::spawn({
241            let proxy = proxy.clone();
242            let topology_id = topology_id.clone();
243            let topology_id_initialized = topology_id_initialized.clone();
244            async move {
245                if let Err(err) =
246                    watch_topology(proxy, topology_id, topology_id_initialized.clone()).await
247                {
248                    error!(err:%; "Failed to watch Registry topology");
249                    // Watching the topology ID will fail if the device does not support signal
250                    // processing. In this case, mark the ID as initialized so the getter can
251                    // return the initial None value.
252                    topology_id_initialized.signal();
253                }
254            }
255        });
256
257        Self {
258            proxy,
259            element_states,
260            topology_id,
261            element_states_initialized,
262            topology_id_initialized,
263            _watch_element_states_task: watch_element_states_task,
264            _watch_topology_task: watch_topology_task,
265        }
266    }
267
268    /// Returns this device's signal processing elements, or `None` if the device does not support
269    /// signal processing.
270    pub async fn elements(&self) -> Result<Option<Vec<Element>>, Error> {
271        let response = self
272            .proxy
273            .get_elements()
274            .await
275            .context("failed to call GetElements")?
276            .map_err(Status::from_raw);
277
278        if let Err(Status::NOT_SUPPORTED) = response {
279            return Ok(None);
280        }
281
282        let elements = response
283            .context("failed to get elements")?
284            .into_iter()
285            .map(TryInto::try_into)
286            .collect::<Result<Vec<_>, _>>()
287            .map_err(|err| anyhow!("Invalid element: {}", err))?;
288
289        Ok(Some(elements))
290    }
291
292    /// Returns this device's signal processing topologies, or `None` if the device does not
293    /// support signal processing.
294    pub async fn topologies(&self) -> Result<Option<Vec<Topology>>, Error> {
295        let response = self
296            .proxy
297            .get_topologies()
298            .await
299            .context("failed to call GetTopologies")?
300            .map_err(Status::from_raw);
301
302        if let Err(Status::NOT_SUPPORTED) = response {
303            return Ok(None);
304        }
305
306        let topologies = response
307            .context("failed to get topologies")?
308            .into_iter()
309            .map(TryInto::try_into)
310            .collect::<Result<Vec<_>, _>>()
311            .map_err(|err| anyhow!("Invalid topology: {}", err))?;
312
313        Ok(Some(topologies))
314    }
315
316    /// Returns the current signal processing topology ID, or `None` if the device does not support
317    /// signal processing.
318    pub async fn topology_id(&self) -> Option<fadevice::TopologyId> {
319        self.topology_id_initialized.wait().await;
320        *self.topology_id.lock().await
321    }
322
323    /// Returns the state of the signal processing element with the given `element_id`.
324    ///
325    /// Returns None if there is no element with the given ID, or if the device does not support
326    /// signal processing.
327    pub async fn element_state(&self, element_id: fadevice::ElementId) -> Option<ElementState> {
328        self.element_states_initialized.wait().await;
329        self.element_states
330            .lock()
331            .await
332            .as_ref()
333            .and_then(|states| states.get(&element_id).cloned())
334    }
335
336    /// Returns states of all signal processing elements, or `None` if the device does not support
337    /// signal processing.
338    pub async fn element_states(&self) -> Option<BTreeMap<fadevice::ElementId, ElementState>> {
339        self.element_states_initialized.wait().await;
340        self.element_states.lock().await.clone()
341    }
342}
343
344/// Watches element state changes on a registry device and updates `element_states` with the
345/// current state for each element.
346///
347/// Signals `element_states_initialized` when `element_states` is populated
348/// with the initial set of states.
349async fn watch_element_states(
350    proxy: fadevice::ObserverProxy,
351    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
352    element_states_initialized: AsyncEvent,
353) -> Result<(), Error> {
354    let mut element_states_initialized = Some(element_states_initialized);
355
356    let element_ids = {
357        let get_elements_response = proxy
358            .get_elements()
359            .await
360            .context("failed to call GetElements")?
361            .map_err(Status::from_raw);
362
363        if let Err(Status::NOT_SUPPORTED) = get_elements_response {
364            element_states_initialized.take().unwrap().signal();
365            return Ok(());
366        }
367
368        get_elements_response
369            .context("failed to get elements")?
370            .into_iter()
371            .map(|element| element.id.ok_or_else(|| anyhow!("missing element 'id'")))
372            .collect::<Result<Vec<_>, _>>()?
373    };
374
375    // Contains element IDs for which we haven't received an initial state.
376    let mut uninitialized_element_ids = BTreeSet::from_iter(element_ids.iter().copied());
377
378    let state_streams = element_ids.into_iter().map(|element_id| {
379        HangingGetStream::new(proxy.clone(), move |p| p.watch_element_state(element_id))
380            .map(move |element_state_result| (element_id, element_state_result))
381    });
382
383    let mut all_states_stream = futures::stream::select_all(state_streams);
384
385    while let Some((element_id, element_state_result)) = all_states_stream.next().await {
386        let element_state: ElementState = element_state_result
387            .context("failed to call WatchElementState")?
388            .try_into()
389            .map_err(|err| anyhow!("Invalid element state: {}", err))?;
390        let mut element_states = element_states.lock().await;
391        let element_states_map = element_states.get_or_insert_with(BTreeMap::new);
392        let _ = element_states_map.insert(element_id, element_state);
393
394        // Signal `element_states_initialized` once all elements have initial states.
395        if element_states_initialized.is_some() {
396            let _ = uninitialized_element_ids.remove(&element_id);
397            if uninitialized_element_ids.is_empty() {
398                element_states_initialized.take().unwrap().signal();
399            }
400        }
401    }
402
403    Ok(())
404}
405
406/// Watches topology changes on a registry device and updates `topology_id`
407/// when the topology changes.
408///
409/// Signals `topology_id_initialized` when `topology_id` is populated
410/// with the initial topology.
411async fn watch_topology(
412    proxy: fadevice::ObserverProxy,
413    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,
414    topology_id_initialized: AsyncEvent,
415) -> Result<(), Error> {
416    let mut topology_id_initialized = Some(topology_id_initialized);
417
418    let mut topology_stream =
419        HangingGetStream::new(proxy.clone(), fadevice::ObserverProxy::watch_topology);
420
421    while let Some(topology_result) = topology_stream.next().await {
422        let new_topology_id = topology_result.context("failed to call WatchTopology")?;
423
424        *topology_id.lock().await = Some(new_topology_id);
425
426        if let Some(topology_id_initialized) = topology_id_initialized.take() {
427            topology_id_initialized.signal();
428        }
429    }
430
431    Ok(())
432}
433
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use super::*;
    use async_utils::hanging_get::server::{HangingGet, Publisher};
    use fidl_test_util::spawn_local_stream_handler;

    type AddedResponse = fadevice::RegistryWatchDevicesAddedResponse;
    type AddedResponder = fadevice::RegistryWatchDevicesAddedResponder;
    type AddedNotifyFn = Box<dyn Fn(&AddedResponse, AddedResponder) -> bool>;
    type AddedPublisher = Publisher<AddedResponse, AddedResponder, AddedNotifyFn>;

    type RemovedResponse = fadevice::RegistryWatchDeviceRemovedResponse;
    type RemovedResponder = fadevice::RegistryWatchDeviceRemovedResponder;
    type RemovedNotifyFn = Box<dyn Fn(&RemovedResponse, RemovedResponder) -> bool>;
    type RemovedPublisher = Publisher<RemovedResponse, RemovedResponder, RemovedNotifyFn>;

    /// Serves a fake `Registry` whose first `WatchDevicesAdded` reply lists `initial_devices`.
    ///
    /// Returns the client proxy plus publishers that let a test push further
    /// `WatchDevicesAdded` / `WatchDeviceRemoved` replies.
    fn serve_registry(
        initial_devices: Vec<fadevice::Info>,
    ) -> (fadevice::RegistryProxy, AddedPublisher, RemovedPublisher) {
        let notify_added: AddedNotifyFn = Box::new(|response, responder: AddedResponder| {
            responder.send(Ok(response)).expect("failed to send response");
            true
        });
        let notify_removed: RemovedNotifyFn = Box::new(|response, responder: RemovedResponder| {
            responder.send(Ok(response)).expect("failed to send response");
            true
        });

        let mut added_broker = HangingGet::new(added_response(initial_devices), notify_added);
        let added_publisher = added_broker.new_publisher();

        // Nothing has been removed yet, so the first WatchDeviceRemoved call should hang.
        let mut removed_broker = HangingGet::new_unknown_state(notify_removed);
        let removed_publisher = removed_broker.new_publisher();

        let added = Rc::new(Mutex::new(added_broker.new_subscriber()));
        let removed = Rc::new(Mutex::new(removed_broker.new_subscriber()));

        let proxy = spawn_local_stream_handler(move |request| {
            let added = Rc::clone(&added);
            let removed = Rc::clone(&removed);
            async move {
                match request {
                    fadevice::RegistryRequest::WatchDevicesAdded { responder } => {
                        added.lock().await.register(responder).unwrap()
                    }
                    fadevice::RegistryRequest::WatchDeviceRemoved { responder } => {
                        removed.lock().await.register(responder).unwrap()
                    }
                    _ => unimplemented!(),
                }
            }
        });

        (proxy, added_publisher, removed_publisher)
    }

    /// Builds a `WatchDevicesAdded` reply listing `devices`.
    fn added_response(devices: Vec<fadevice::Info>) -> fadevice::RegistryWatchDevicesAddedResponse {
        fadevice::RegistryWatchDevicesAddedResponse { devices: Some(devices), ..Default::default() }
    }

    /// Builds a `WatchDeviceRemoved` reply for `token_id`.
    fn removed_response(
        token_id: fadevice::TokenId,
    ) -> fadevice::RegistryWatchDeviceRemovedResponse {
        fadevice::RegistryWatchDeviceRemovedResponse {
            token_id: Some(token_id),
            ..Default::default()
        }
    }

    #[fuchsia::test]
    async fn test_device_info() {
        // The publishers are bound to named locals so they stay alive for the whole test.
        let (registry_proxy, _added_publisher, _removed_publisher) =
            serve_registry(vec![fadevice::Info { token_id: Some(1), ..Default::default() }]);
        let registry = Registry::new(registry_proxy);

        assert!(registry.device_info(1).await.is_some());
        assert!(registry.device_info(2).await.is_none());
    }

    #[fuchsia::test]
    async fn test_subscribe() {
        let (registry_proxy, added_publisher, removed_publisher) = serve_registry(Vec::new());
        let registry = Registry::new(registry_proxy);

        registry.devices_initialized.wait().await;

        let mut events_receiver = registry.subscribe().await;

        // Publish a WatchDevicesAdded response with two devices; expect one Added event each.
        added_publisher.set(added_response(vec![
            fadevice::Info { token_id: Some(1), ..Default::default() },
            fadevice::Info { token_id: Some(2), ..Default::default() },
        ]));

        let added_events: Vec<DeviceEvent> = events_receiver.by_ref().take(2).collect().await;
        let mut added_token_ids = Vec::new();
        for event in &added_events {
            if let DeviceEvent::Added(info) = event {
                added_token_ids.push(info.token_id());
            }
        }
        added_token_ids.sort();
        assert_eq!(added_token_ids, vec![1, 2]);

        // Publish a WatchDeviceRemoved response; expect exactly one Removed event.
        removed_publisher.set(removed_response(2));

        let removed_events: Vec<DeviceEvent> = events_receiver.take(1).collect().await;
        let removed_token_ids: Vec<fadevice::TokenId> = removed_events
            .iter()
            .filter_map(|event| match event {
                DeviceEvent::Removed(token_id) => Some(*token_id),
                _ => None,
            })
            .collect();
        assert_eq!(removed_token_ids, vec![2]);
    }
}