socket_proxy/
dns_watcher.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements the fuchsia.net.policy.socketproxy.DnsServerWatcher service.
6
7use anyhow::{Context, Error};
8use fidl::endpoints::{ControlHandle as _, RequestStream as _, Responder as _};
9use fidl_fuchsia_net_policy_socketproxy::{self as fnp_socketproxy, DnsServerList};
10use fuchsia_inspect_derive::{IValue, Inspect, Unit};
11use futures::channel::mpsc;
12use futures::lock::Mutex;
13use futures::{StreamExt, TryStreamExt};
14use log::{info, warn};
15use std::sync::Arc;
16
/// Bookkeeping for a DnsServerWatcher connection.
///
/// Only the two counters are exported through inspect; every other field is
/// marked `#[inspect(skip)]`.
#[derive(Unit, Debug, Default)]
struct DnsServerWatcherState {
    /// The most recent server list received on the DNS update channel.
    #[inspect(skip)]
    server_list: Vec<DnsServerList>,
    /// The list most recently sent to the client, or `None` if nothing has
    /// been sent yet. Compared against `server_list` to suppress duplicates.
    #[inspect(skip)]
    last_sent: Option<Vec<DnsServerList>>,
    /// The responder of a `WatchServers` call that has not been answered yet.
    #[inspect(skip)]
    queued_responder: Option<fnp_socketproxy::DnsServerWatcherWatchServersResponder>,

    /// Number of server lists received on the DNS update channel.
    updates_seen: u32,
    /// Number of server lists successfully sent to the client.
    updates_sent: u32,
}
29
/// A wrapper around the fuchsia.net.policy.socketproxy.DnsServerWatcher service
/// that tracks when a DnsServerList update needs to be sent.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone shares the same
/// state and the same update receiver.
#[derive(Inspect, Debug, Clone)]
pub(crate) struct DnsServerWatcher {
    /// Shared watcher state, forwarded to inspect. `run` takes this lock with
    /// `try_lock` and keeps it while serving, which is how a second concurrent
    /// connection is detected and rejected.
    #[inspect(forward)]
    state: Arc<Mutex<IValue<DnsServerWatcherState>>>,
    /// Receiving half of the channel that delivers new DNS server lists.
    dns_rx: Arc<Mutex<mpsc::Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
}
38
39impl DnsServerWatcher {
40    /// Create a new DnsServerWatcher.
41    pub(crate) fn new(
42        dns_rx: Arc<Mutex<mpsc::Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
43    ) -> Self {
44        Self { dns_rx, state: Default::default() }
45    }
46
47    /// Runs the fuchsia.net.policy.socketproxy.DnsServerWatcher service.
48    pub(crate) async fn run<'a>(
49        &self,
50        stream: fnp_socketproxy::DnsServerWatcherRequestStream,
51    ) -> Result<(), Error> {
52        let mut state = match self.state.try_lock() {
53            Some(o) => o,
54            None => {
55                warn!("Only one connection to DnsServerWatcher is allowed at a time");
56                stream.control_handle().shutdown_with_epitaph(fidl::Status::ACCESS_DENIED);
57                return Ok(());
58            }
59        };
60        let mut dns_rx = self.dns_rx.lock().await;
61        info!("Starting fuchsia.net.policy.socketproxy.DnsServerWatcher server");
62        let mut stream = stream.map(|result| result.context("failed request")).fuse();
63
64        loop {
65            futures::select! {
66                request = stream.try_next() => match request? {
67                    Some(fnp_socketproxy::DnsServerWatcherRequest::WatchServers { responder }) => {
68                        let mut state = state.as_mut();
69                        if state.queued_responder.is_some() {
70                            warn!("Only one call to watch server may be active at once");
71                            responder
72                                .control_handle()
73                                .shutdown_with_epitaph(fidl::Status::ACCESS_DENIED);
74                        } else {
75                            state.queued_responder = Some(responder);
76                            state.maybe_respond()?;
77                        }
78                    },
79                    None => {}
80                },
81                dns_update = dns_rx.select_next_some() => {
82                    let mut state = state.as_mut();
83                    state.updates_seen += 1;
84                    state.server_list = dns_update;
85                    state.maybe_respond()?;
86                }
87            }
88        }
89    }
90}
91
92impl DnsServerWatcherState {
93    fn maybe_respond(&mut self) -> Result<(), Error> {
94        if self.last_sent.as_ref() != Some(&self.server_list) {
95            if let Some(responder) = self.queued_responder.take() {
96                info!("Sending DNS update to client: {}", self.server_list.len());
97                responder.send(&self.server_list)?;
98                self.updates_sent += 1;
99                self.last_sent = Some(self.server_list.clone());
100            }
101        }
102        Ok(())
103    }
104}
105
106#[cfg(test)]
107mod test {
108    use super::*;
109    use assert_matches::assert_matches;
110    use diagnostics_assertions::assert_data_tree;
111    use fuchsia_component::server::ServiceFs;
112    use fuchsia_component_test::{
113        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
114    };
115    use fuchsia_inspect_derive::WithInspect;
116    use futures::SinkExt as _;
117    use futures::channel::mpsc::{Receiver, Sender};
118    use pretty_assertions::assert_eq;
119
    /// Protocols the test component serves from its outgoing `svc` directory.
    enum IncomingService {
        /// An incoming connection to the DnsServerWatcher protocol.
        DnsServerWatcher(fnp_socketproxy::DnsServerWatcherRequestStream),
    }
123
124    async fn run_registry(
125        handles: LocalComponentHandles,
126        dns_rx: Arc<Mutex<Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
127    ) -> Result<(), Error> {
128        let mut fs = ServiceFs::new();
129        let _ = fs.dir("svc").add_fidl_service(IncomingService::DnsServerWatcher);
130        let _ = fs.serve_connection(handles.outgoing_dir)?;
131
132        let watcher = DnsServerWatcher::new(dns_rx)
133            .with_inspect(fuchsia_inspect::component::inspector().root(), "dns_watcher")?;
134
135        fs.for_each_concurrent(0, |IncomingService::DnsServerWatcher(stream)| {
136            let watcher = watcher.clone();
137            async move {
138                watcher
139                    .run(stream)
140                    .await
141                    .context("Failed to serve request stream")
142                    .unwrap_or_else(|e| eprintln!("Error encountered: {e:?}"))
143            }
144        })
145        .await;
146
147        Ok(())
148    }
149
150    async fn setup_test()
151    -> Result<(RealmInstance, Sender<Vec<fnp_socketproxy::DnsServerList>>), Error> {
152        let builder = RealmBuilder::new().await?;
153        let (dns_tx, dns_rx) = mpsc::channel(1);
154        let dns_rx = Arc::new(Mutex::new(dns_rx));
155        let registry = builder
156            .add_local_child(
157                "dns_watcher",
158                {
159                    let dns_rx = dns_rx.clone();
160                    move |handles: LocalComponentHandles| {
161                        Box::pin(run_registry(handles, dns_rx.clone()))
162                    }
163                },
164                ChildOptions::new(),
165            )
166            .await?;
167
168        builder
169            .add_route(
170                Route::new()
171                    .capability(Capability::protocol::<fnp_socketproxy::DnsServerWatcherMarker>())
172                    .from(&registry)
173                    .to(Ref::parent()),
174            )
175            .await?;
176
177        let realm = builder.build().await?;
178
179        Ok((realm, dns_tx))
180    }
181
182    #[fuchsia::test]
183    async fn test_normal_operation() -> Result<(), Error> {
184        let (realm, mut dns_tx) = setup_test().await?;
185
186        let dns_server_watcher: fnp_socketproxy::DnsServerWatcherProxy = realm
187            .root
188            .connect_to_protocol_at_exposed_dir()
189            .context("While connecting to DnsServerWatcher")?;
190
191        // Initial watch should return immediately
192        assert_eq!(dns_server_watcher.watch_servers().await?, vec![]);
193
194        // Send a new DNS update
195        let (send_result, watch_result) = futures::future::join(
196            dns_tx.send(vec![DnsServerList { source_network_id: Some(0), ..Default::default() }]),
197            dns_server_watcher.watch_servers(),
198        )
199        .await;
200
201        assert_matches!(send_result, Ok(()));
202        assert_eq!(
203            watch_result?,
204            vec![DnsServerList { source_network_id: Some(0), ..Default::default() }]
205        );
206
207        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
208            dns_watcher: {
209                updates_seen: 1u64,
210                updates_sent: 2u64,
211            },
212        });
213
214        Ok(())
215    }
216
217    #[fuchsia::test]
218    async fn test_duplicate_list() -> Result<(), Error> {
219        let (realm, mut dns_tx) = setup_test().await?;
220        let dns_server_watcher: fnp_socketproxy::DnsServerWatcherProxy = realm
221            .root
222            .connect_to_protocol_at_exposed_dir()
223            .context("While connecting to DnsServerWatcher")?;
224
225        // Initial watch should return immediately
226        assert_eq!(dns_server_watcher.watch_servers().await?, vec![]);
227
228        let server_list = vec![DnsServerList { source_network_id: Some(0), ..Default::default() }];
229
230        let mut dns_tx2 = dns_tx.clone();
231        let mut dns_tx3 = dns_tx.clone();
232        let (watch_result, s1, s2, s3) = futures::join!(
233            dns_server_watcher.watch_servers(),
234            dns_tx.send(server_list.clone()),
235            dns_tx2.send(server_list.clone()),
236            dns_tx3.send(server_list.clone()),
237        );
238
239        assert_matches!(s1, Ok(()));
240        assert_matches!(s2, Ok(()));
241        assert_matches!(s3, Ok(()));
242        assert_eq!(watch_result?, server_list);
243
244        // Send a new (distinct) DNS update
245        let (send_result, watch_result) = futures::future::join(
246            dns_tx.send(vec![DnsServerList { source_network_id: Some(1), ..Default::default() }]),
247            dns_server_watcher.watch_servers(),
248        )
249        .await;
250        assert_matches!(send_result, Ok(()));
251
252        // We expect that this watch should get the new server list, not one of
253        // the old duplicate ones.
254        assert_eq!(
255            watch_result?,
256            vec![DnsServerList { source_network_id: Some(1), ..Default::default() }]
257        );
258
259        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
260            dns_watcher: {
261                updates_seen: 4u64,
262                updates_sent: 3u64,
263            },
264        });
265
266        Ok(())
267    }
268
269    #[fuchsia::test]
270    async fn test_duplicate_watch() -> Result<(), Error> {
271        let (realm, _dns_tx) = setup_test().await?;
272
273        let dns_server_watcher: fnp_socketproxy::DnsServerWatcherProxy = realm
274            .root
275            .connect_to_protocol_at_exposed_dir()
276            .context("While connecting to DnsServerWatcher")?;
277
278        // Initial watch should return immediately
279        assert_eq!(dns_server_watcher.watch_servers().await?, vec![]);
280
281        let watch1 = dns_server_watcher.watch_servers();
282        let watch2 = dns_server_watcher.watch_servers();
283
284        // Two simultaneous calls to watch_servers is invalid and will cause the
285        // watcher channel to be closed.
286        assert_matches!(
287            futures::future::join(watch1, watch2).await,
288            (
289                Err(fidl::Error::ClientChannelClosed { status: fidl::Status::ACCESS_DENIED, .. }),
290                Err(fidl::Error::ClientChannelClosed { status: fidl::Status::ACCESS_DENIED, .. })
291            )
292        );
293
294        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
295            dns_watcher: {
296                updates_seen: 0u64,
297                updates_sent: 1u64,
298            },
299        });
300
301        Ok(())
302    }
303}