1mod stream;
8#[cfg(test)]
9mod test_util;
10
11use std::cmp::Ordering;
12use std::collections::{HashMap, HashSet};
13
14use fidl_fuchsia_net::SocketAddress;
15use fidl_fuchsia_net_name::{
16 DhcpDnsServerSource, Dhcpv6DnsServerSource, DnsServerSource, DnsServer_, NdpDnsServerSource,
17 StaticDnsServerSource,
18};
19
20pub use self::stream::*;
21
/// The default DNS server port.
pub const DEFAULT_DNS_PORT: u16 = 53;
24
/// The DNS servers learned from every source, kept separately per source so
/// that they can be consolidated into a single prioritized, deduplicated list.
#[derive(Default)]
pub struct DnsServers {
    /// Servers recorded for the `Default` update source.
    ///
    /// Consolidation appends these after the sorted servers from all other
    /// sources, so they always come last in the consolidated list.
    default: Vec<DnsServer_>,

    /// Servers recorded for the `Netstack` update source.
    netstack: Vec<DnsServer_>,

    /// Servers recorded for the `Dhcpv4` update source, keyed by interface ID.
    dhcpv4: HashMap<u64, Vec<DnsServer_>>,

    /// Servers recorded for the `Dhcpv6` update source, keyed by interface ID.
    dhcpv6: HashMap<u64, Vec<DnsServer_>>,

    /// Servers recorded for the `Ndp` update source, keyed by interface ID.
    ndp: HashMap<u64, Vec<DnsServer_>>,
}
45
46impl DnsServers {
47 pub fn set_servers_from_source(
51 &mut self,
52 source: DnsServersUpdateSource,
53 servers: Vec<DnsServer_>,
54 ) {
55 let Self { default, netstack, dhcpv4, dhcpv6, ndp } = self;
56
57 match source {
58 DnsServersUpdateSource::Default => *default = servers,
59 DnsServersUpdateSource::Netstack => *netstack = servers,
60 DnsServersUpdateSource::Dhcpv4 { interface_id } => {
61 let _: Option<Vec<DnsServer_>> = if servers.is_empty() {
64 dhcpv4.remove(&interface_id)
65 } else {
66 dhcpv4.insert(interface_id, servers)
67 };
68 }
69 DnsServersUpdateSource::Dhcpv6 { interface_id } => {
70 let _: Option<Vec<DnsServer_>> = if servers.is_empty() {
73 dhcpv6.remove(&interface_id)
74 } else {
75 dhcpv6.insert(interface_id, servers)
76 };
77 }
78 DnsServersUpdateSource::Ndp { interface_id } => {
79 let _: Option<Vec<DnsServer_>> = if servers.is_empty() {
82 ndp.remove(&interface_id)
83 } else {
84 ndp.insert(interface_id, servers)
85 };
86 }
87 }
88 }
89
90 pub fn consolidated(&self) -> Vec<SocketAddress> {
111 self.consolidate_filter_map(|x| x.address)
112 }
113
114 pub fn consolidated_dns_servers(&self) -> Vec<DnsServer_> {
133 self.consolidate_filter_map(|x| Some(x))
134 }
135
136 fn consolidate_filter_map<T, F: Fn(DnsServer_) -> Option<T>>(&self, f: F) -> Vec<T> {
140 let Self { default, netstack, dhcpv4, dhcpv6, ndp } = self;
141 let mut servers = netstack
142 .iter()
143 .chain(dhcpv4.values().flatten())
144 .chain(ndp.values().flatten())
145 .chain(dhcpv6.values().flatten())
146 .cloned()
147 .collect::<Vec<_>>();
148 let () = servers.sort_by(Self::ordering);
154 let () = servers.extend(default.clone());
157 let mut addresses = HashSet::new();
158 let () = servers.retain(move |s| addresses.insert(s.address));
159 servers.into_iter().filter_map(f).collect()
160 }
161
162 fn ordering(a: &DnsServer_, b: &DnsServer_) -> Ordering {
176 let ordering = |source| match source {
177 Some(&DnsServerSource::Dhcp(DhcpDnsServerSource { source_interface: _, .. })) => 0,
178 Some(&DnsServerSource::Ndp(NdpDnsServerSource { source_interface: _, .. })) => 1,
179 Some(&DnsServerSource::Dhcpv6(Dhcpv6DnsServerSource {
180 source_interface: _, ..
181 })) => 2,
182 Some(&DnsServerSource::StaticSource(StaticDnsServerSource { .. })) | None => 3,
183 };
184 let a = ordering(a.source.as_ref());
185 let b = ordering(b.source.as_ref());
186 std::cmp::Ord::cmp(&a, &b)
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::constants::*;

    /// Servers repeated within a single source (and across interfaces of one
    /// source) must appear only once in the consolidated list.
    #[test]
    fn deduplicate_within_source() {
        let servers = DnsServers {
            default: vec![ndp_server(), ndp_server()],
            netstack: vec![ndp_server(), static_server(), ndp_server(), static_server()],
            // Both interfaces report the same servers in the same order, so
            // the expected result does not depend on map iteration order.
            dhcpv4: [
                (DHCPV4_SERVER1_INTERFACE_ID, vec![dhcpv4_server1(), dhcpv4_server2()]),
                (DHCPV4_SERVER2_INTERFACE_ID, vec![dhcpv4_server1(), dhcpv4_server2()]),
            ]
            .into_iter()
            .collect(),
            dhcpv6: [
                (DHCPV6_SERVER1_INTERFACE_ID, vec![dhcpv6_server1(), dhcpv6_server2()]),
                (DHCPV6_SERVER2_INTERFACE_ID, vec![dhcpv6_server1(), dhcpv6_server2()]),
            ]
            .into_iter()
            .collect(),
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_server(), ndp_server()])]
                .into_iter()
                .collect(),
        };
        assert_eq!(
            servers.consolidated(),
            vec![
                DHCPV4_SOURCE_SOCKADDR1,
                DHCPV4_SOURCE_SOCKADDR2,
                NDP_SOURCE_SOCKADDR,
                DHCPV6_SOURCE_SOCKADDR1,
                DHCPV6_SOURCE_SOCKADDR2,
                STATIC_SOURCE_SOCKADDR,
            ],
        );
    }

    /// Servers in `default` must not displace the same addresses learned from
    /// other sources, and must end up after them.
    #[test]
    fn default_low_prio() {
        let servers = DnsServers {
            default: vec![static_server(), dhcpv4_server1(), dhcpv6_server1()],
            netstack: vec![static_server()],
            dhcpv4: [
                (DHCPV4_SERVER1_INTERFACE_ID, vec![dhcpv4_server1()]),
                (DHCPV4_SERVER2_INTERFACE_ID, vec![dhcpv4_server1()]),
            ]
            .into_iter()
            .collect(),
            dhcpv6: [
                (DHCPV6_SERVER1_INTERFACE_ID, vec![dhcpv6_server1()]),
                (DHCPV6_SERVER2_INTERFACE_ID, vec![dhcpv6_server2()]),
            ]
            .into_iter()
            .collect(),
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_server()])].into_iter().collect(),
        };
        let mut got = servers.consolidated();
        // Servers learned on different interfaces of one source come out in
        // `HashMap` iteration order, so each source's slice of the result is
        // compared as a set rather than as a sequence.
        let mut got = got.drain(..);
        let want_dhcpv4 = [DHCPV4_SOURCE_SOCKADDR1];
        assert_eq!(
            HashSet::from_iter(got.by_ref().take(want_dhcpv4.len())),
            HashSet::from(want_dhcpv4),
        );

        let want_ndp = [NDP_SOURCE_SOCKADDR];
        assert_eq!(HashSet::from_iter(got.by_ref().take(want_ndp.len())), HashSet::from(want_ndp));

        let want_dhcpv6 = [DHCPV6_SOURCE_SOCKADDR1, DHCPV6_SOURCE_SOCKADDR2];
        assert_eq!(
            HashSet::from_iter(got.by_ref().take(want_dhcpv6.len())),
            HashSet::from(want_dhcpv6),
        );

        // Whatever is left must be exactly the static server.
        let want_rest = [STATIC_SOURCE_SOCKADDR];
        assert_eq!(got.as_slice(), want_rest);
    }

    /// When several sources report the same address, only the entry from the
    /// highest priority source must be kept.
    #[test]
    fn deduplicate_across_sources() {
        // A DHCPv6-sourced server that reuses the NDP server's address.
        let dhcpv6_with_ndp_address = || DnsServer_ {
            address: Some(NDP_SOURCE_SOCKADDR),
            source: Some(DnsServerSource::Dhcpv6(Dhcpv6DnsServerSource {
                source_interface: Some(DHCPV6_SERVER1_INTERFACE_ID),
                ..Default::default()
            })),
            ..Default::default()
        };
        let mut servers = DnsServers {
            default: vec![],
            netstack: vec![dhcpv4_server1(), static_server()],
            dhcpv4: [(DHCPV4_SERVER1_INTERFACE_ID, vec![dhcpv4_server1()])].into_iter().collect(),
            dhcpv6: [(
                DHCPV6_SERVER1_INTERFACE_ID,
                vec![dhcpv6_with_ndp_address(), dhcpv6_server1()],
            )]
            .into_iter()
            .collect(),
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_server(), dhcpv6_with_ndp_address()])]
                .into_iter()
                .collect(),
        };
        // The DHCPv6 duplicate of the NDP address must never appear in the result.
        let expected_servers =
            vec![dhcpv4_server1(), ndp_server(), dhcpv6_server1(), static_server()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        let expected_sockaddrs = vec![
            DHCPV4_SOURCE_SOCKADDR1,
            NDP_SOURCE_SOCKADDR,
            DHCPV6_SOURCE_SOCKADDR1,
            STATIC_SOURCE_SOCKADDR,
        ];
        assert_eq!(servers.consolidated(), expected_sockaddrs);
        // Reporting the duplicate through the netstack as well changes nothing.
        servers.netstack = vec![dhcpv4_server1(), static_server(), dhcpv6_with_ndp_address()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        assert_eq!(servers.consolidated(), expected_sockaddrs);

        // An NDP-sourced server that uses `DHCPV6_SOURCE_SOCKADDR1` as its
        // address (presumably the address of `dhcpv6_server1()` - see test_util).
        let ndp_with_dhcpv6_sockaddr1 = || DnsServer_ {
            address: Some(DHCPV6_SOURCE_SOCKADDR1),
            source: Some(DnsServerSource::Ndp(NdpDnsServerSource {
                source_interface: Some(NDP_SERVER_INTERFACE_ID),
                ..Default::default()
            })),
            ..Default::default()
        };

        let mut dhcpv6 = HashMap::new();
        assert_matches::assert_matches!(
            dhcpv6.insert(DHCPV6_SERVER1_INTERFACE_ID, vec![dhcpv6_server1()]),
            None
        );
        assert_matches::assert_matches!(
            dhcpv6.insert(DHCPV6_SERVER2_INTERFACE_ID, vec![dhcpv6_server2()]),
            None
        );
        let mut servers = DnsServers {
            default: vec![],
            netstack: vec![static_server()],
            dhcpv4: Default::default(),
            dhcpv6,
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_with_dhcpv6_sockaddr1()])]
                .into_iter()
                .collect(),
        };
        // The NDP entry is expected in place of `dhcpv6_server1()`.
        let expected_servers = vec![ndp_with_dhcpv6_sockaddr1(), dhcpv6_server2(), static_server()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        let expected_sockaddrs =
            vec![DHCPV6_SOURCE_SOCKADDR1, DHCPV6_SOURCE_SOCKADDR2, STATIC_SOURCE_SOCKADDR];
        assert_eq!(servers.consolidated(), expected_sockaddrs);
        servers.netstack = vec![static_server(), ndp_with_dhcpv6_sockaddr1()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        assert_eq!(servers.consolidated(), expected_sockaddrs);
    }

    /// Checks the source-priority comparator directly.
    #[test]
    fn test_dns_servers_ordering() {
        assert_eq!(DnsServers::ordering(&ndp_server(), &ndp_server()), Ordering::Equal);
        assert_eq!(DnsServers::ordering(&dhcpv4_server1(), &dhcpv4_server1()), Ordering::Equal);
        assert_eq!(DnsServers::ordering(&dhcpv6_server1(), &dhcpv6_server1()), Ordering::Equal);
        assert_eq!(DnsServers::ordering(&static_server(), &static_server()), Ordering::Equal);
        assert_eq!(
            DnsServers::ordering(&unspecified_source_server(), &unspecified_source_server()),
            Ordering::Equal
        );
        assert_eq!(
            DnsServers::ordering(&static_server(), &unspecified_source_server()),
            Ordering::Equal
        );

        let servers = [
            dhcpv4_server1(),
            ndp_server(),
            dhcpv6_server1(),
            static_server(),
            unspecified_source_server(),
        ];
        // The last two servers compare equal to each other (asserted above),
        // so they are left out of the outer loop; every remaining server must
        // order strictly before all servers that follow it.
        for (i, a) in servers[..servers.len() - 2].iter().enumerate() {
            for b in servers[i + 1..].iter() {
                assert_eq!(DnsServers::ordering(a, b), Ordering::Less);
            }
        }

        let mut servers = vec![dhcpv6_server1(), dhcpv4_server1(), static_server(), ndp_server()];
        servers.sort_by(DnsServers::ordering);
        assert_eq!(
            servers,
            vec![dhcpv4_server1(), ndp_server(), dhcpv6_server1(), static_server()]
        );
    }
}