openthread_fuchsia/backing/
resolver.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::*;
6use anyhow::Context as _;
7use fidl_fuchsia_net_name::DnsServerWatcherMarker;
8use fuchsia_sync::Mutex;
9use openthread::ot::DnsUpstream;
10use openthread_sys::*;
11use std::collections::HashMap;
12use std::hash::{Hash, Hasher};
13use std::net::{Ipv6Addr, SocketAddr};
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker};
16
/// Size in bytes of the buffer used to receive a single upstream DNS response datagram.
const MAX_DNS_RESPONSE_SIZE: usize = 2_048;
18
19struct DnsUpstreamQueryRefWrapper(&'static ot::PlatDnsUpstreamQuery);
20
21impl DnsUpstreamQueryRefWrapper {
22    fn as_ptr(&self) -> *const ot::PlatDnsUpstreamQuery {
23        std::ptr::from_ref(self.0)
24    }
25}
26
27impl PartialEq for DnsUpstreamQueryRefWrapper {
28    fn eq(&self, other: &Self) -> bool {
29        self.as_ptr().eq(&other.as_ptr())
30    }
31}
32
33impl Eq for DnsUpstreamQueryRefWrapper {}
34
35impl Hash for DnsUpstreamQueryRefWrapper {
36    fn hash<H>(&self, state: &mut H)
37    where
38        H: Hasher,
39    {
40        self.as_ptr().hash(state);
41    }
42}
43
44struct Transaction {
45    // This field is set but never accessed directly, so we need to silence this warning
46    // so that we can compile.
47    #[allow(unused)]
48    // The task that performs socket poll and forwards the DNS reply from socket.
49    task: fasync::Task<Result<(), anyhow::Error>>,
50    // Receive the DNS reply from the `task` which stores the corresponding sender.
51    receiver: fmpsc::UnboundedReceiver<(DnsUpstreamQueryRefWrapper, Vec<u8>)>,
52}
53
54struct LocalDnsServerList {
55    // A local copy of the DNS server list
56    dns_server_list: Vec<fidl_fuchsia_net_name::DnsServer_>,
57    // This field is set but never accessed directly, so we need to silence this warning
58    // so that we can compile
59    #[allow(unused)]
60    // The task that awaits on the DNS server list change.
61    task: fasync::Task<Result<(), anyhow::Error>>,
62    // Receive the DNS server list from the `task` which stores the corresponding sender.
63    receiver: fmpsc::UnboundedReceiver<Vec<fidl_fuchsia_net_name::DnsServer_>>,
64}
65
66pub(crate) struct Resolver {
67    // The Map that uses the `DnsUpstreamQueryRefWrapper` as key to quickly locate the Transaction
68    transactions_map: Arc<Mutex<HashMap<DnsUpstreamQueryRefWrapper, Transaction>>>,
69    // Maintains a local DNS record for immediately sending out the DNS upstream query.
70    local_dns_record: RefCell<Option<LocalDnsServerList>>,
71    waker: Cell<Option<Waker>>,
72}
73
74impl Resolver {
75    pub fn new() -> Resolver {
76        if let Ok(proxy) =
77            fuchsia_component::client::connect_to_protocol::<DnsServerWatcherMarker>()
78        {
79            let (mut sender, receiver) = fmpsc::unbounded();
80
81            // Create a future that await for the latest DNS server list, and forward it to the
82            // corresponding receiver. The future is executed in the task in `LocalDnsServerList`.
83            let dns_list_watcher_fut = async move {
84                loop {
85                    let vec = proxy.watch_servers().await?;
86                    info!(tag = "resolver"; "getting latest DNS server list: {:?}", vec);
87                    if let Err(e) = sender.send(vec).await {
88                        warn!(
89                            tag = "resolver";
90                            "error when sending out latest dns list to process_poll, {:?}", e
91                        );
92                    }
93                }
94            };
95            Resolver {
96                transactions_map: Default::default(),
97                waker: Default::default(),
98                local_dns_record: RefCell::new(Some(LocalDnsServerList {
99                    dns_server_list: Vec::new(),
100                    task: fuchsia_async::Task::spawn(dns_list_watcher_fut),
101                    receiver,
102                })),
103            }
104        } else {
105            warn!(
106                tag = "resolver";
107                "failed to connect to `DnsServerWatcherMarker`, \
108                         DNS upstream query will not be supported"
109            );
110            Resolver {
111                transactions_map: Arc::new(Mutex::new(HashMap::new())),
112                waker: Cell::new(None),
113                local_dns_record: RefCell::new(None),
114            }
115        }
116    }
117
118    pub fn is_upstream_query_available(&self) -> bool {
119        if let Some(local_dns_record) = self.local_dns_record.borrow().as_ref() {
120            !local_dns_record.dns_server_list.is_empty()
121        } else {
122            false
123        }
124    }
125
126    pub fn process_poll_resolver(&self, instance: &ot::Instance, cx: &mut Context<'_>) {
127        // Update the waker so that we can later signal when we need to be polled again
128        self.waker.replace(Some(cx.waker().clone()));
129
130        // Poll the DNS server list task
131        if let Some(local_dns_record) = self.local_dns_record.borrow_mut().as_mut() {
132            while let Poll::Ready(Some(dns_server_list)) =
133                local_dns_record.receiver.poll_next_unpin(cx)
134            {
135                // DNS server watcher proxy returns the new DNS server list when something changed
136                // in netstack. The outdated list should be replaced by the new one.
137                local_dns_record.dns_server_list = dns_server_list;
138            }
139        }
140
141        let mut remove_key_vec = Vec::new();
142        // Poll the socket in each transaction. If a response is ready, forward it to the OpenThread
143        // and remove the corresponding transaction.
144        for (_, transaction) in self.transactions_map.lock().iter_mut() {
145            while let Poll::Ready(Some((context, message_vec))) =
146                transaction.receiver.poll_next_unpin(cx)
147            {
148                if let Ok(mut message) =
149                    ot::Message::udp_new(instance, None).context("cannot create UDP message")
150                {
151                    match message.append(&message_vec) {
152                        Ok(_) => {
153                            instance.plat_dns_upstream_query_done(context.0, message);
154                        }
155                        Err(e) => {
156                            warn!(tag = "resolver"; "failed to append to `ot::Message`: {}", e);
157                        }
158                    }
159                } else {
160                    warn!(
161                        tag = "resolver";
162                        "failed to create `ot::Message`, drop the upstream DNS response"
163                    );
164                }
165                remove_key_vec.push(context);
166            }
167        }
168
169        // cancel the transaction
170        for key in remove_key_vec {
171            self.transactions_map.lock().remove(&key);
172        }
173    }
174
175    fn on_start_dns_upstream_query<'a>(
176        &self,
177        _instance: &ot::Instance,
178        thread_context: &'static ot::PlatDnsUpstreamQuery,
179        dns_query: &ot::Message<'_>,
180    ) {
181        let sockaddr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 53);
182        let socket = match fuchsia_async::net::UdpSocket::bind(&sockaddr) {
183            Ok(socket) => socket,
184            Err(_) => {
185                warn!(
186                    tag = "resolver";
187                    "on_start_dns_upstream_query() failed to create UDP socket, ignoring the query"
188                );
189                return;
190            }
191        };
192
193        let query_bytes = dns_query.to_vec();
194
195        // Get the DNS server list, and send out the query to all the available DNS servers.
196        if let Some(local_dns_record) = self.local_dns_record.borrow().as_ref() {
197            for dns_server in &local_dns_record.dns_server_list {
198                if let Some(address) = dns_server.address {
199                    match address {
200                        fidl_fuchsia_net::SocketAddress::Ipv4(ipv4_sock_addr) => {
201                            let sock_addr = SocketAddr::new(
202                                std::net::IpAddr::V4(std::net::Ipv4Addr::from(
203                                    ipv4_sock_addr.address.addr,
204                                )),
205                                ipv4_sock_addr.port,
206                            );
207                            info!(
208                                tag = "resolver";
209                                "sending DNS query to IPv4 server {}", sock_addr
210                            );
211                            if let Some(Err(e)) =
212                                socket.send_to(&query_bytes, sock_addr).now_or_never()
213                            {
214                                warn!(
215                                    tag = "resolver";
216                                    "Failed to send DNS query to IPv4 server {}: {}", sock_addr, e
217                                );
218                            }
219                        }
220                        fidl_fuchsia_net::SocketAddress::Ipv6(ipv6_sock_addr) => {
221                            let sock_addr = SocketAddr::new(
222                                std::net::IpAddr::V6(std::net::Ipv6Addr::from(
223                                    ipv6_sock_addr.address.addr,
224                                )),
225                                ipv6_sock_addr.port,
226                            );
227
228                            info!(
229                                tag = "resolver";
230                                "sending DNS query to IPv6 server {}", sock_addr
231                            );
232                            if let Some(Err(e)) =
233                                socket.send_to(&query_bytes, sock_addr).now_or_never()
234                            {
235                                warn!(
236                                    tag = "resolver";
237                                    "Failed to send DNS query to IPv6 server {}: {}", sock_addr, e
238                                );
239                            }
240                        }
241                    }
242                }
243            }
244
245            let (mut sender, receiver) = fmpsc::unbounded();
246
247            // Create a poll_fn for the socket that can be await on
248            let receive_from_fut = futures::future::poll_fn(move |cx| {
249                let mut buffer = [0u8; MAX_DNS_RESPONSE_SIZE];
250                match socket.async_recv_from(&mut buffer, cx) {
251                    Poll::Ready(Ok((len, sockaddr))) => {
252                        let message = buffer[..len].to_vec();
253                        Poll::Ready(Ok((message, sockaddr)))
254                    }
255                    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
256                    Poll::Pending => Poll::Pending,
257                }
258            });
259
260            // Create a future that forward the DNS reply from socket to process_poll
261            let fut = async move {
262                let (message_vec, sockaddr) =
263                    receive_from_fut.await.context("error receiving from dns upstream socket")?;
264
265                info!(
266                    tag = "resolver";
267                    "Incoming {} bytes DNS response from {:?}",
268                    message_vec.len(),
269                    sockaddr
270                );
271                if let Err(e) =
272                    sender.send((DnsUpstreamQueryRefWrapper(thread_context), message_vec)).await
273                {
274                    warn!(
275                        tag = "resolver";
276                        "error when sending out dns upstream reply to process_poll, {:?}", e
277                    );
278                }
279                Ok(())
280            };
281
282            // Socket and the sender is owned by the task now
283            let transaction = Transaction { task: fuchsia_async::Task::spawn(fut), receiver };
284
285            self.transactions_map
286                .lock()
287                .insert(DnsUpstreamQueryRefWrapper(thread_context), transaction);
288        } else {
289            warn!(
290                tag = "resolver";
291                "on_start_dns_upstream_query() failed to get local_dns_record, ignoring the query"
292            );
293        }
294
295        // Trigger the waker so that our poll method gets called by the executor
296        self.waker.replace(None).and_then(|waker| {
297            waker.wake();
298            Some(())
299        });
300    }
301
302    // Cancel the pending query
303    fn on_cancel_dns_upstream_query(
304        &self,
305        _instance: &ot::Instance,
306        thread_context: &'static ot::PlatDnsUpstreamQuery,
307    ) {
308        if let None =
309            self.transactions_map.lock().remove(&DnsUpstreamQueryRefWrapper(thread_context))
310        {
311            warn!(
312                tag = "resolver";
313                "on_cancel_dns_upstream_query() target transaction not presented for remove, ignoring"
314            );
315        }
316    }
317}
318
319#[unsafe(no_mangle)]
320unsafe extern "C" fn otPlatDnsStartUpstreamQuery(
321    a_instance: *mut otInstance,
322    a_txn: *mut otPlatDnsUpstreamQuery,
323    a_query: *const otMessage,
324) {
325    Resolver::on_start_dns_upstream_query(
326        &unsafe { PlatformBacking::as_ref() }.resolver,
327        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
328        //         which is guaranteed by the caller.
329        unsafe { ot::Instance::ref_from_ot_ptr(a_instance) }.unwrap(),
330        // SAFETY: no dereference is happening in fuchsia platform side
331        unsafe { ot::PlatDnsUpstreamQuery::mut_from_ot_mut_ptr(a_txn) }.unwrap(),
332        // SAFETY: caller ensures the dns query is valid
333        unsafe { ot::Message::ref_from_ot_ptr(a_query as *mut otMessage) }.unwrap(),
334    )
335}
336
337#[unsafe(no_mangle)]
338unsafe extern "C" fn otPlatDnsIsUpstreamQueryAvailable(a_instance: *mut otInstance) -> bool {
339    // The instance parameter is part of the function signature but unused in this implementation.
340    // We still get a reference to it to match the pattern of other platform functions.
341    let _ = unsafe { ot::Instance::ref_from_ot_ptr(a_instance) };
342    unsafe { PlatformBacking::as_ref() }.resolver.is_upstream_query_available()
343}
344
345#[unsafe(no_mangle)]
346unsafe extern "C" fn otPlatDnsCancelUpstreamQuery(
347    a_instance: *mut otInstance,
348    a_txn: *mut otPlatDnsUpstreamQuery,
349) {
350    Resolver::on_cancel_dns_upstream_query(
351        &unsafe { PlatformBacking::as_ref() }.resolver,
352        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
353        //         which is guaranteed by the caller.
354        unsafe { ot::Instance::ref_from_ot_ptr(a_instance) }.unwrap(),
355        // SAFETY: no dereference is happening in fuchsia platform side
356        unsafe { ot::PlatDnsUpstreamQuery::mut_from_ot_mut_ptr(a_txn) }.unwrap(),
357    )
358}