1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use crate::device::Info as DeviceInfo;
use anyhow::{anyhow, Context, Error};
use async_utils::event::Event as AsyncEvent;
use async_utils::hanging_get::client::HangingGetStream;
use fidl_fuchsia_audio_device as fadevice;
use fuchsia_async::Task;
use futures::{lock::Mutex, StreamExt};
use std::collections::BTreeMap;
use std::sync::Arc;

/// Client-side view of the devices reported by a `fuchsia.audio.device`
/// registry.
///
/// Constructed with [`Registry::new`], which spawns a background task that
/// keeps the device map up to date.
pub struct Registry {
    /// Known devices keyed by token ID. Shared with the background
    /// `watch_devices` task, which inserts and removes entries.
    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
    /// Signaled by `watch_devices` once the first `WatchDevicesAdded`
    /// response has been applied to `devices`. `get` and `get_all` wait on
    /// this before reading the map.
    devices_initialized: AsyncEvent,
    /// Handle to the spawned `watch_devices` task. Never read; presumably
    /// held so the task lives as long as the `Registry` (depends on
    /// `fuchsia_async::Task` drop semantics — TODO confirm).
    _watch_devices_task: Task<Result<(), Error>>,
}

impl Registry {
    /// Creates a `Registry` that mirrors the devices reported through `proxy`.
    ///
    /// Spawns a task running `watch_devices`, handing it a clone of the shared
    /// device map and of the initialization event. The task handle is stored
    /// in the returned value.
    pub fn new(proxy: fadevice::RegistryProxy) -> Self {
        let devices = Arc::new(Mutex::new(BTreeMap::new()));
        let devices_initialized = AsyncEvent::new();
        // The watcher owns its own handles to the map and the event; the
        // originals stay with `Self` for the accessors below.
        let watcher = watch_devices(proxy, Arc::clone(&devices), devices_initialized.clone());
        Self { devices, devices_initialized, _watch_devices_task: Task::spawn(watcher) }
    }

    /// Looks up the device identified by `token_id`.
    ///
    /// Waits for the initial set of devices to be recorded before reading the
    /// map, then returns a clone of the matching entry, or `None` when no
    /// device with that ID is present.
    ///
    /// NOTE(review): the event is only signaled after the watcher applies its
    /// first `WatchDevicesAdded` response; if the watcher fails before that,
    /// this call presumably never completes — confirm against `AsyncEvent`.
    pub async fn get(&self, token_id: fadevice::TokenId) -> Option<DeviceInfo> {
        self.devices_initialized.wait().await;
        let devices = self.devices.lock().await;
        devices.get(&token_id).cloned()
    }

    /// Returns a snapshot of every device currently in the registry, keyed by
    /// token ID.
    ///
    /// Like [`Registry::get`], this first waits for the initial set of devices
    /// to be recorded.
    pub async fn get_all(&self) -> BTreeMap<fadevice::TokenId, DeviceInfo> {
        self.devices_initialized.wait().await;
        let devices = self.devices.lock().await;
        devices.clone()
    }
}

/// Keeps `devices` in sync with the registry reachable through `proxy`.
///
/// Two hanging-get streams are polled side by side:
///
/// * `WatchDevicesAdded`: every returned device is inserted into `devices`
///   under its token ID, replacing any existing entry with the same ID.
/// * `WatchDeviceRemoved`: the returned token ID is removed from `devices`.
///
/// `devices_initialized` is signaled exactly once, right after the first
/// `WatchDevicesAdded` response has been applied.
///
/// The loop has no exit of its own. The function returns only with an error:
/// when a FIDL call fails, when the registry replies with an error, or when a
/// response is missing a required field. Devices inserted before such an
/// error stay in the map.
async fn watch_devices(
    proxy: fadevice::RegistryProxy,
    devices: Arc<Mutex<BTreeMap<fadevice::TokenId, DeviceInfo>>>,
    devices_initialized: AsyncEvent,
) -> Result<(), Error> {
    // Holds the event until it has been signaled; `None` afterwards so later
    // "added" batches do not signal again.
    let mut pending_init_signal = Some(devices_initialized);

    let mut added_stream =
        HangingGetStream::new(proxy.clone(), fadevice::RegistryProxy::watch_devices_added);
    let mut removed_stream =
        HangingGetStream::new(proxy, fadevice::RegistryProxy::watch_device_removed);

    loop {
        futures::select! {
            added_result = added_stream.select_next_some() => {
                // Outer error: the FIDL call itself. Inner error: the
                // registry's reply.
                let batch = match added_result.context("failed to call WatchDevicesAdded")? {
                    Ok(response) => response.devices,
                    Err(err) => {
                        return Err(anyhow!("failed to watch for added devices: {:?}", err));
                    }
                };
                let Some(batch) = batch else {
                    return Err(anyhow!("missing devices"));
                };

                // The lock is held until the end of this branch, so the
                // signal below fires while the map already holds the batch.
                let mut known = devices.lock().await;
                for info in batch {
                    let Some(token_id) = info.token_id else {
                        return Err(anyhow!("device info missing token_id"));
                    };
                    let _ = known.insert(token_id, DeviceInfo::from(info));
                }

                if let Some(event) = pending_init_signal.take() {
                    event.signal();
                }
            },
            removed_result = removed_stream.select_next_some() => {
                let removed_id = match removed_result.context("failed to call WatchDeviceRemoved")? {
                    Ok(response) => response.token_id,
                    Err(err) => {
                        return Err(anyhow!("failed to watch for removed device: {:?}", err));
                    }
                };
                let Some(removed_id) = removed_id else {
                    return Err(anyhow!("missing token_id"));
                };

                let mut known = devices.lock().await;
                let _ = known.remove(&removed_id);
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use async_utils::hanging_get::server::{HangingGet, Publisher};
    use fidl::endpoints::spawn_local_stream_handler;

    // Hanging-get server types for `WatchDevicesAdded`.
    type AddedResponse = fadevice::RegistryWatchDevicesAddedResponse;
    type AddedResponder = fadevice::RegistryWatchDevicesAddedResponder;
    type AddedNotifyFn = Box<dyn Fn(&AddedResponse, AddedResponder) -> bool>;
    type AddedPublisher = Publisher<AddedResponse, AddedResponder, AddedNotifyFn>;

    // Hanging-get server types for `WatchDeviceRemoved`.
    type RemovedResponse = fadevice::RegistryWatchDeviceRemovedResponse;
    type RemovedResponder = fadevice::RegistryWatchDeviceRemovedResponder;
    type RemovedNotifyFn = Box<dyn Fn(&RemovedResponse, RemovedResponder) -> bool>;
    type RemovedPublisher = Publisher<RemovedResponse, RemovedResponder, RemovedNotifyFn>;

    /// Spawns a fake registry server and returns a proxy connected to it,
    /// together with publishers for the "added" and "removed" hanging gets.
    ///
    /// The "added" broker starts with `initial_devices` as its state; the
    /// "removed" broker is created via `new_unknown_state`, i.e. without an
    /// initial value. Requests other than `WatchDevicesAdded` and
    /// `WatchDeviceRemoved` hit `unimplemented!()`.
    fn serve_registry(
        initial_devices: Vec<fadevice::Info>,
    ) -> (fadevice::RegistryProxy, AddedPublisher, RemovedPublisher) {
        // "Added" side: seeded with the initial device list.
        let notify_added: AddedNotifyFn = Box::new(|state, responder: AddedResponder| {
            responder.send(Ok(state)).expect("failed to send response");
            true
        });
        let added_broker = HangingGet::new(
            AddedResponse { devices: Some(initial_devices), ..Default::default() },
            notify_added,
        );
        let added_publisher = added_broker.new_publisher();

        // "Removed" side: no state until something is published.
        let notify_removed: RemovedNotifyFn = Box::new(|state, responder: RemovedResponder| {
            responder.send(Ok(state)).expect("failed to send response");
            true
        });
        let removed_broker = HangingGet::new_unknown_state(notify_removed);
        let removed_publisher = removed_broker.new_publisher();

        // The request handler is invoked once per request, so the brokers are
        // shared behind `Arc<Mutex<_>>` and cloned into each handler future.
        let shared_added = Arc::new(Mutex::new(added_broker));
        let shared_removed = Arc::new(Mutex::new(removed_broker));

        let proxy = spawn_local_stream_handler(move |request| {
            let added_broker = Arc::clone(&shared_added);
            let removed_broker = Arc::clone(&shared_removed);
            async move {
                // A fresh subscriber is created on both brokers for every
                // incoming request, regardless of which method was called.
                let added_subscriber = added_broker.lock().await.new_subscriber();
                let removed_subscriber = removed_broker.lock().await.new_subscriber();
                match request {
                    fadevice::RegistryRequest::WatchDevicesAdded { responder } => {
                        added_subscriber.register(responder).unwrap()
                    }
                    fadevice::RegistryRequest::WatchDeviceRemoved { responder } => {
                        removed_subscriber.register(responder).unwrap()
                    }
                    _ => unimplemented!(),
                }
            }
        })
        .unwrap();

        (proxy, added_publisher, removed_publisher)
    }

    /// `Registry::get` finds a device from the initial set and reports `None`
    /// for an unknown token ID.
    #[fuchsia::test]
    async fn test_get() {
        let initial_devices = vec![fadevice::Info { token_id: Some(1), ..Default::default() }];
        // The publishers are kept alive for the duration of the test.
        let (proxy, _added_publisher, _removed_publisher) = serve_registry(initial_devices);
        let registry = Registry::new(proxy);

        let known = registry.get(1).await;
        assert!(known.is_some());

        let unknown = registry.get(2).await;
        assert!(unknown.is_none());
    }
}