fuchsia_audio/
registry.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::device::Info as DeviceInfo;
6use crate::sigproc::{Element, ElementState, Topology};
7use anyhow::{anyhow, Context, Error};
8use async_utils::event::Event as AsyncEvent;
9use async_utils::hanging_get::client::HangingGetStream;
10use fidl::endpoints::create_proxy;
11use fidl_fuchsia_audio_device as fadevice;
12use fuchsia_async::Task;
13use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use futures::lock::Mutex;
15use futures::StreamExt;
16use log::error;
17use std::collections::{BTreeMap, BTreeSet};
18use std::sync::Arc;
19use zx_status::Status;
20
/// A change to the set of devices known to a [Registry].
#[derive(Debug, Clone)]
pub enum DeviceEvent {
    /// A device was added to the registry.
    Added(Box<DeviceInfo>),
    /// A device was removed from the registry.
    Removed(fadevice::TokenId),
}

/// Sending half of a [DeviceEvent] channel. The registry keeps one per subscriber.
pub type DeviceEventSender = UnboundedSender<DeviceEvent>;
/// Receiving half of a [DeviceEvent] channel, as returned by [Registry::subscribe].
pub type DeviceEventReceiver = UnboundedReceiver<DeviceEvent>;
31
/// Client for `fuchsia.audio.device.Registry` that keeps a local copy of the registry's devices,
/// updated by a background watch task.
pub struct Registry {
    /// Registry connection; also used to create observers in [Registry::observe].
    proxy: fadevice::RegistryProxy,
    /// Devices currently in the registry, keyed by token ID.
    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
    /// Signaled once `devices` holds the initial set of devices.
    devices_initialized: AsyncEvent,
    /// One sender per [Registry::subscribe] call; device events are broadcast to each of them.
    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
    /// Background task running `watch_devices`, owned by the registry. Presumably it is
    /// cancelled when the registry (and thus the `Task`) is dropped.
    _watch_devices_task: Task<()>,
}
39
40impl Registry {
41    pub fn new(proxy: fadevice::RegistryProxy) -> Self {
42        let devices = Arc::new(Mutex::new(BTreeMap::new()));
43        let devices_initialized = AsyncEvent::new();
44        let event_senders = Arc::new(Mutex::new(vec![]));
45
46        let watch_devices_task = Task::spawn({
47            let proxy = proxy.clone();
48            let devices = devices.clone();
49            let devices_initialized = devices_initialized.clone();
50            let event_senders = event_senders.clone();
51            async {
52                if let Err(err) =
53                    watch_devices(proxy, devices, devices_initialized, event_senders).await
54                {
55                    error!(err:%; "Failed to watch Registry devices");
56                }
57            }
58        });
59
60        Self {
61            proxy,
62            devices,
63            devices_initialized,
64            event_senders,
65            _watch_devices_task: watch_devices_task,
66        }
67    }
68
69    /// Returns information about the device with the given `token_id`.
70    ///
71    /// Returns None if there is no device with the given ID.
72    pub async fn device_info(&self, token_id: fadevice::TokenId) -> Option<DeviceInfo> {
73        self.devices_initialized.wait().await;
74        self.devices.lock().await.get(&token_id).cloned()
75    }
76
77    /// Returns information about all devices in the registry.
78    pub async fn device_infos(&self) -> BTreeMap<fadevice::TokenId, DeviceInfo> {
79        self.devices_initialized.wait().await;
80        self.devices.lock().await.clone()
81    }
82
83    /// Returns a [RegistryDevice] that observes the device with the given `token_id`.
84    ///
85    /// Returns an error if there is no device with the given token ID.
86    pub async fn observe(&self, token_id: fadevice::TokenId) -> Result<RegistryDevice, Error> {
87        self.devices_initialized.wait().await;
88
89        let info = self
90            .devices
91            .lock()
92            .await
93            .get(&token_id)
94            .cloned()
95            .ok_or_else(|| anyhow!("Device with ID {} does not exist", token_id))?;
96
97        let (observer_proxy, observer_server) = create_proxy::<fadevice::ObserverMarker>();
98
99        let _ = self
100            .proxy
101            .create_observer(fadevice::RegistryCreateObserverRequest {
102                token_id: Some(token_id),
103                observer_server: Some(observer_server),
104                ..Default::default()
105            })
106            .await
107            .context("Failed to call CreateObserver")?
108            .map_err(|err| anyhow!("failed to create device observer: {:?}", err))?;
109
110        Ok(RegistryDevice::new(info, observer_proxy))
111    }
112
113    /// Returns a channel of device events.
114    pub async fn subscribe(&self) -> DeviceEventReceiver {
115        let (sender, receiver) = mpsc::unbounded::<DeviceEvent>();
116        self.event_senders.lock().await.push(sender);
117        receiver
118    }
119}
120
121/// Watches devices added to and removed from the registry and updates
122/// `devices` with the current state.
123///
124/// Signals `devices_initialized` when `devices` is populated with the initial
125/// set of devices.
126async fn watch_devices(
127    proxy: fadevice::RegistryProxy,
128    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
129    devices_initialized: AsyncEvent,
130    event_senders: Arc<Mutex<Vec<DeviceEventSender>>>,
131) -> Result<(), Error> {
132    let mut devices_initialized = Some(devices_initialized);
133
134    let mut devices_added_stream =
135        HangingGetStream::new(proxy.clone(), fadevice::RegistryProxy::watch_devices_added);
136    let mut device_removed_stream =
137        HangingGetStream::new(proxy, fadevice::RegistryProxy::watch_device_removed);
138
139    loop {
140        futures::select! {
141            added = devices_added_stream.select_next_some() => {
142                let response = added
143                    .context("failed to call WatchDevicesAdded")?
144                    .map_err(|err| anyhow!("failed to watch for added devices: {:?}", err))?;
145                let added_devices = response.devices.ok_or_else(|| anyhow!("missing devices"))?;
146
147                let mut devices = devices.lock().await;
148                let mut event_senders = event_senders.lock().await;
149
150                for new_device in added_devices.into_iter() {
151                    let token_id = new_device.token_id.ok_or_else(|| anyhow!("device info missing token_id"))?;
152                    let device_info = DeviceInfo::from(new_device);
153                    for sender in event_senders.iter_mut() {
154                        let _ = sender.unbounded_send(DeviceEvent::Added(Box::new(device_info.clone())));
155                    }
156                    let _ = devices.insert(token_id, device_info);
157                }
158
159                if let Some(devices_initialized) = devices_initialized.take() {
160                    devices_initialized.signal();
161                }
162            },
163            removed = device_removed_stream.select_next_some() => {
164                let response = removed
165                    .context("failed to call WatchDeviceRemoved")?
166                    .map_err(|err| anyhow!("failed to watch for removed device: {:?}", err))?;
167                let token_id = response.token_id.ok_or_else(|| anyhow!("missing token_id"))?;
168                let mut devices = devices.lock().await;
169                let _ = devices.remove(&token_id);
170                for sender in event_senders.lock().await.iter_mut() {
171                    let _ = sender.unbounded_send(DeviceEvent::Removed(token_id));
172                }
173            }
174        }
175    }
176}
177
/// A registry device observed through a `fuchsia.audio.device.Observer` connection.
pub struct RegistryDevice {
    /// Snapshot of the device's info taken when the device was observed.
    _info: DeviceInfo,
    /// Observer connection for this device.
    _proxy: fadevice::ObserverProxy,

    /// Signal processing client for this device.
    ///
    /// If None, this device does not support signal processing: its info reports no signal
    /// processing elements or no signal processing topologies.
    pub signal_processing: Option<SignalProcessing>,
}
185
186impl RegistryDevice {
187    pub fn new(info: DeviceInfo, proxy: fadevice::ObserverProxy) -> Self {
188        let is_signal_processing_supported = info.0.signal_processing_elements.is_some()
189            && info.0.signal_processing_topologies.is_some();
190        let signal_processing =
191            is_signal_processing_supported.then(|| SignalProcessing::new(proxy.clone()));
192
193        Self { _info: info, _proxy: proxy, signal_processing }
194    }
195}
196
/// Client for the composed signal processing `Reader` in a `fuchsia.audio.device.Observer`.
///
/// Element states and the current topology ID are kept up to date by background tasks.
pub struct SignalProcessing {
    /// Observer connection used for signal processing queries.
    proxy: fadevice::ObserverProxy,

    /// Latest state of each element, keyed by element ID.
    ///
    /// Stays `None` if the device reports signal processing as unsupported, or if watching fails
    /// before any state arrives.
    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
    /// Latest topology ID reported by the device, or `None` if none has been received.
    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,

    /// Signaled once `element_states` is initialized, or once watching element states fails.
    element_states_initialized: AsyncEvent,
    /// Signaled once `topology_id` is initialized, or once watching the topology fails.
    topology_id_initialized: AsyncEvent,

    /// Background task running `watch_element_states`.
    _watch_element_states_task: Task<()>,
    /// Background task running `watch_topology`.
    _watch_topology_task: Task<()>,
}
210
211impl SignalProcessing {
212    fn new(proxy: fadevice::ObserverProxy) -> Self {
213        let element_states = Arc::new(Mutex::new(None));
214        let topology_id = Arc::new(Mutex::new(None));
215
216        let element_states_initialized = AsyncEvent::new();
217        let topology_id_initialized = AsyncEvent::new();
218
219        let watch_element_states_task = Task::spawn({
220            let proxy = proxy.clone();
221            let element_states = element_states.clone();
222            let element_states_initialized = element_states_initialized.clone();
223            async move {
224                if let Err(err) =
225                    watch_element_states(proxy, element_states, element_states_initialized.clone())
226                        .await
227                {
228                    error!(err:%; "Failed to watch Registry element states");
229                    // Watching the element states will fail if the device does not support signal
230                    // processing. In this case, mark the states as initialized so the getter can
231                    // return the initial None value.
232                    element_states_initialized.signal();
233                }
234            }
235        });
236
237        let watch_topology_task = Task::spawn({
238            let proxy = proxy.clone();
239            let topology_id = topology_id.clone();
240            let topology_id_initialized = topology_id_initialized.clone();
241            async move {
242                if let Err(err) =
243                    watch_topology(proxy, topology_id, topology_id_initialized.clone()).await
244                {
245                    error!(err:%; "Failed to watch Registry topology");
246                    // Watching the topology ID will fail if the device does not support signal
247                    // processing. In this case, mark the ID as initialized so the getter can
248                    // return the initial None value.
249                    topology_id_initialized.signal();
250                }
251            }
252        });
253
254        Self {
255            proxy,
256            element_states,
257            topology_id,
258            element_states_initialized,
259            topology_id_initialized,
260            _watch_element_states_task: watch_element_states_task,
261            _watch_topology_task: watch_topology_task,
262        }
263    }
264
265    /// Returns this device's signal processing elements, or `None` if the device does not support
266    /// signal processing.
267    pub async fn elements(&self) -> Result<Option<Vec<Element>>, Error> {
268        let response = self
269            .proxy
270            .get_elements()
271            .await
272            .context("failed to call GetElements")?
273            .map_err(|status| Status::from_raw(status));
274
275        if let Err(Status::NOT_SUPPORTED) = response {
276            return Ok(None);
277        }
278
279        let elements = response
280            .context("failed to get elements")?
281            .into_iter()
282            .map(TryInto::try_into)
283            .collect::<Result<Vec<_>, _>>()
284            .map_err(|err| anyhow!("Invalid element: {}", err))?;
285
286        Ok(Some(elements))
287    }
288
289    /// Returns this device's signal processing topologies, or `None` if the device does not
290    /// support signal processing.
291    pub async fn topologies(&self) -> Result<Option<Vec<Topology>>, Error> {
292        let response = self
293            .proxy
294            .get_topologies()
295            .await
296            .context("failed to call GetTopologies")?
297            .map_err(|status| Status::from_raw(status));
298
299        if let Err(Status::NOT_SUPPORTED) = response {
300            return Ok(None);
301        }
302
303        let topologies = response
304            .context("failed to get topologies")?
305            .into_iter()
306            .map(TryInto::try_into)
307            .collect::<Result<Vec<_>, _>>()
308            .map_err(|err| anyhow!("Invalid topology: {}", err))?;
309
310        Ok(Some(topologies))
311    }
312
313    /// Returns the current signal processing topology ID, or `None` if the device does not support
314    /// signal processing.
315    pub async fn topology_id(&self) -> Option<fadevice::TopologyId> {
316        self.topology_id_initialized.wait().await;
317        *self.topology_id.lock().await
318    }
319
320    /// Returns the state of the signal processing element with the given `element_id`.
321    ///
322    /// Returns None if there is no element with the given ID, or if the device does not support
323    /// signal processing.
324    pub async fn element_state(&self, element_id: fadevice::ElementId) -> Option<ElementState> {
325        self.element_states_initialized.wait().await;
326        self.element_states
327            .lock()
328            .await
329            .as_ref()
330            .and_then(|states| states.get(&element_id).cloned())
331    }
332
333    /// Returns states of all signal processing elements, or `None` if the device does not support
334    /// signal processing.
335    pub async fn element_states(&self) -> Option<BTreeMap<fadevice::ElementId, ElementState>> {
336        self.element_states_initialized.wait().await;
337        self.element_states.lock().await.clone()
338    }
339}
340
341/// Watches element state changes on a registry device and updates `element_states` with the
342/// current state for each element.
343///
344/// Signals `element_states_initialized` when `element_states` is populated
345/// with the initial set of states.
346async fn watch_element_states(
347    proxy: fadevice::ObserverProxy,
348    element_states: Arc<Mutex<Option<BTreeMap<fadevice::ElementId, ElementState>>>>,
349    element_states_initialized: AsyncEvent,
350) -> Result<(), Error> {
351    let mut element_states_initialized = Some(element_states_initialized);
352
353    let element_ids = {
354        let get_elements_response = proxy
355            .get_elements()
356            .await
357            .context("failed to call GetElements")?
358            .map_err(|status| Status::from_raw(status));
359
360        if let Err(Status::NOT_SUPPORTED) = get_elements_response {
361            element_states_initialized.take().unwrap().signal();
362            return Ok(());
363        }
364
365        get_elements_response
366            .context("failed to get elements")?
367            .into_iter()
368            .map(|element| element.id.ok_or_else(|| anyhow!("missing element 'id'")))
369            .collect::<Result<Vec<_>, _>>()?
370    };
371
372    // Contains element IDs for which we haven't received an initial state.
373    let mut uninitialized_element_ids = BTreeSet::from_iter(element_ids.iter().copied());
374
375    let state_streams = element_ids.into_iter().map(|element_id| {
376        HangingGetStream::new(proxy.clone(), move |p| p.watch_element_state(element_id))
377            .map(move |element_state_result| (element_id, element_state_result))
378    });
379
380    let mut all_states_stream = futures::stream::select_all(state_streams);
381
382    while let Some((element_id, element_state_result)) = all_states_stream.next().await {
383        let element_state: ElementState = element_state_result
384            .context("failed to call WatchElementState")?
385            .try_into()
386            .map_err(|err| anyhow!("Invalid element state: {}", err))?;
387        let mut element_states = element_states.lock().await;
388        let element_states_map = element_states.get_or_insert_with(|| BTreeMap::new());
389        let _ = element_states_map.insert(element_id, element_state);
390
391        // Signal `element_states_initialized` once all elements have initial states.
392        if element_states_initialized.is_some() {
393            let _ = uninitialized_element_ids.remove(&element_id);
394            if uninitialized_element_ids.is_empty() {
395                element_states_initialized.take().unwrap().signal();
396            }
397        }
398    }
399
400    Ok(())
401}
402
403/// Watches topology changes on a registry device and updates `topology_id`
404/// when the topology changes.
405///
406/// Signals `topology_id_initialized` when `topology_id` is populated
407/// with the initial topology.
408async fn watch_topology(
409    proxy: fadevice::ObserverProxy,
410    topology_id: Arc<Mutex<Option<fadevice::TopologyId>>>,
411    topology_id_initialized: AsyncEvent,
412) -> Result<(), Error> {
413    let mut topology_id_initialized = Some(topology_id_initialized);
414
415    let mut topology_stream =
416        HangingGetStream::new(proxy.clone(), fadevice::ObserverProxy::watch_topology);
417
418    while let Some(topology_result) = topology_stream.next().await {
419        let new_topology_id = topology_result.context("failed to call WatchTopology")?;
420
421        *topology_id.lock().await = Some(new_topology_id);
422
423        if let Some(topology_id_initialized) = topology_id_initialized.take() {
424            topology_id_initialized.signal();
425        }
426    }
427
428    Ok(())
429}
430
#[cfg(test)]
mod test {
    use super::*;
    use async_utils::hanging_get::server::{HangingGet, Publisher};
    use fidl::endpoints::spawn_local_stream_handler;

    // Hanging-get server types backing the fake `Registry.WatchDevicesAdded`.
    type AddedResponse = fadevice::RegistryWatchDevicesAddedResponse;
    type AddedResponder = fadevice::RegistryWatchDevicesAddedResponder;
    type AddedNotifyFn = Box<dyn Fn(&AddedResponse, AddedResponder) -> bool>;
    type AddedPublisher = Publisher<AddedResponse, AddedResponder, AddedNotifyFn>;

    // Hanging-get server types backing the fake `Registry.WatchDeviceRemoved`.
    type RemovedResponse = fadevice::RegistryWatchDeviceRemovedResponse;
    type RemovedResponder = fadevice::RegistryWatchDeviceRemovedResponder;
    type RemovedNotifyFn = Box<dyn Fn(&RemovedResponse, RemovedResponder) -> bool>;
    type RemovedPublisher = Publisher<RemovedResponse, RemovedResponder, RemovedNotifyFn>;

    /// Serves a fake `fuchsia.audio.device.Registry` that handles only `WatchDevicesAdded` and
    /// `WatchDeviceRemoved`. Any other request hits `unimplemented!()`.
    ///
    /// The added hanging get starts with `initial_devices` as its state. The removed hanging get
    /// is created with `new_unknown_state`, so presumably its first call waits until the removed
    /// publisher is set. Tests use the returned publishers to push further responses.
    fn serve_registry(
        initial_devices: Vec<fadevice::Info>,
    ) -> (fadevice::RegistryProxy, AddedPublisher, RemovedPublisher) {
        let initial_added_response =
            AddedResponse { devices: Some(initial_devices), ..Default::default() };
        // Replies to a pending WatchDevicesAdded call with the published response.
        let watch_devices_added_notify: AddedNotifyFn =
            Box::new(|response, responder: AddedResponder| {
                responder.send(Ok(response)).expect("failed to send response");
                true
            });
        let mut added_broker = HangingGet::new(initial_added_response, watch_devices_added_notify);
        let added_publisher = added_broker.new_publisher();

        // Replies to a pending WatchDeviceRemoved call with the published response.
        let watch_device_removed_notify: RemovedNotifyFn =
            Box::new(|response, responder: RemovedResponder| {
                responder.send(Ok(response)).expect("failed to send response");
                true
            });
        let mut removed_broker = HangingGet::new_unknown_state(watch_device_removed_notify);
        let removed_publisher = removed_broker.new_publisher();

        // A single subscriber per hanging get, shared across requests on the one connection.
        let added_subscriber = Arc::new(Mutex::new(added_broker.new_subscriber()));
        let removed_subscriber = Arc::new(Mutex::new(removed_broker.new_subscriber()));

        let proxy = spawn_local_stream_handler(move |request| {
            let added_subscriber = added_subscriber.clone();
            let removed_subscriber = removed_subscriber.clone();
            async move {
                match request {
                    fadevice::RegistryRequest::WatchDevicesAdded { responder } => {
                        added_subscriber.lock().await.register(responder).unwrap()
                    }
                    fadevice::RegistryRequest::WatchDeviceRemoved { responder } => {
                        removed_subscriber.lock().await.register(responder).unwrap()
                    }
                    _ => unimplemented!(),
                }
            }
        });

        (proxy, added_publisher, removed_publisher)
    }

    /// Builds a `WatchDevicesAdded` response containing `devices`.
    fn added_response(devices: Vec<fadevice::Info>) -> fadevice::RegistryWatchDevicesAddedResponse {
        fadevice::RegistryWatchDevicesAddedResponse { devices: Some(devices), ..Default::default() }
    }

    /// Builds a `WatchDeviceRemoved` response for the device with `token_id`.
    fn removed_response(
        token_id: fadevice::TokenId,
    ) -> fadevice::RegistryWatchDeviceRemovedResponse {
        fadevice::RegistryWatchDeviceRemovedResponse {
            token_id: Some(token_id),
            ..Default::default()
        }
    }

    /// `device_info` finds a device from the initial device list and returns None for an
    /// unknown token ID.
    #[fuchsia::test]
    async fn test_device_info() {
        let initial_devices = vec![fadevice::Info { token_id: Some(1), ..Default::default() }];
        let (registry_proxy, _added_publisher, _removed_publisher) =
            serve_registry(initial_devices);
        let registry = Registry::new(registry_proxy);

        assert!(registry.device_info(1).await.is_some());
        assert!(registry.device_info(2).await.is_none());
    }

    /// A subscriber receives one Added event per added device and a Removed event for a removed
    /// device.
    #[fuchsia::test]
    async fn test_subscribe() {
        let initial_devices = vec![];
        let (registry_proxy, added_publisher, removed_publisher) = serve_registry(initial_devices);
        let registry = Registry::new(registry_proxy);

        // Wait until the initial (empty) device list has been processed, so the subscriber below
        // only sees the events published by this test.
        registry.devices_initialized.wait().await;

        let mut events_receiver = registry.subscribe().await;

        // Publish a WatchDevicesAdded response with two devices and verify that we receive it.
        added_publisher.set(added_response(vec![
            fadevice::Info { token_id: Some(1), ..Default::default() },
            fadevice::Info { token_id: Some(2), ..Default::default() },
        ]));

        // There should be two events, one for each device.
        let events: Vec<_> = events_receiver.by_ref().take(2).collect().await;

        let mut added_token_ids: Vec<_> = events
            .iter()
            .filter_map(|event| match event {
                DeviceEvent::Added(info) => Some(info.token_id()),
                _ => None,
            })
            .collect();
        added_token_ids.sort();
        assert_eq!(added_token_ids, vec![1, 2]);

        // Publish a WatchDeviceRemoved response and verify that we receive it.
        removed_publisher.set(removed_response(2));

        // There should be one event.
        let events: Vec<_> = events_receiver.take(1).collect().await;

        let mut removed_token_ids: Vec<_> = events
            .iter()
            .filter_map(|event| match event {
                DeviceEvent::Removed(token_id) => Some(*token_id),
                _ => None,
            })
            .collect();
        removed_token_ids.sort();
        assert_eq!(removed_token_ids, vec![2]);
    }
}